| |
|
|
| import re |
| import warnings |
| from typing import Callable |
|
|
| import torch |
| from torch.distributed.tensor import DTensor, Partial, Shard |
|
|
| |
# Numerical floor for quantization scale factors: prevents a zero scale (and
# the resulting division-by-zero / inf reciprocal) when a tile's amax is 0.
EPS = 1e-12
|
|
|
|
| def get_splitk(t): |
| |
| |
| |
| if isinstance(t, DTensor) and t.placements == (Shard(dim=1),): |
| return t.device_mesh.size() |
| else: |
| return 1 |
|
|
|
|
def mul_tiled(a, *bs):
    """Multiply ``a`` by per-tile factors.

    Each ``b`` in ``bs`` is a 2-D grid of factors; ``a`` is viewed as a grid of
    ``b.shape[0] x b.shape[-1]`` equal tiles (its dims must divide evenly) and
    every element of tile ``(i, j)`` is scaled by ``b[i, j]``.  Factors are
    applied left to right; the result keeps ``a``'s original shape.
    """
    out = a
    for factors in bs:
        n_row_tiles = factors.shape[0]
        n_col_tiles = factors.shape[-1]
        # View as (row_tile, rows_per_tile, ..., col_tile, cols_per_tile),
        # broadcast one factor per tile, then collapse back.
        tiled = out.unflatten(0, (n_row_tiles, -1)).unflatten(-1, (n_col_tiles, -1))
        tiled = tiled * factors[:, None, :, None]
        out = tiled.flatten(end_dim=1).flatten(start_dim=-2)
    return out
|
|
|
|
| def apply_to_partial(fn, t, *args, **kwargs): |
| |
| |
| |
| if isinstance(t, DTensor) and t.placements == (Partial(),): |
| return torch.distributed.tensor.experimental.local_map(fn, [*t.placements])(t, *args, **kwargs) |
| else: |
| return fn(t, *args, **kwargs) |
|
|
|
|
def scale(t, amax_t):
    """Quantize ``t`` to float8_e4m3fn using per-tile absolute maxima.

    ``amax_t`` holds one abs-max per tile of ``t`` (see ``mul_tiled`` for the
    tiling convention).  Each tile is divided by its scale so its values fill
    the e4m3 representable range.

    Returns:
        (t_fp8, scale_t): the quantized tensor and the float32 scales needed
        to dequantize it (``t ≈ mul_tiled(t_fp8.float(), scale_t)``).
    """
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # Clamp with EPS so an all-zero tile doesn't produce a zero scale.
    scale_t = torch.clamp(amax_t.float(), min=EPS) / fp8_max
    t_fp8 = mul_tiled(t, scale_t.reciprocal()).to(torch.float8_e4m3fn)
    return t_fp8, scale_t
|
|
|
|
def matmul(first, amax_first, second_t, amax_second_t, bias, use_fast_accum):
    """Compute ``first @ second_t.t() + bias`` via a row-wise-scaled fp8 matmul.

    Both operands are quantized to float8_e4m3fn with per-row scales derived
    from the given amaxes.  With ``use_fast_accum`` the scales (and bias) are
    applied inside ``torch._scaled_mm``; without it, unit scales are fed to the
    kernel and the real scales/bias are applied afterwards in bf16 (via
    ``apply_to_partial`` so Partial DTensor outputs are handled per-shard).
    Output dtype is bfloat16.
    """
    first_fp8, scale_first = scale(first, amax_first)
    second_t_fp8, scale_second_t = scale(second_t, amax_second_t)

    if use_fast_accum:
        scale_a = scale_first
        scale_b = scale_second_t.t()
        mm_bias = bias
        deferred_scales = []
        deferred_bias = None
    else:
        # Defer scaling/bias to after the kernel: pass unit scales and no bias.
        # (The original (1, 1) ones tensor is its own transpose, so building it
        # un-transposed here is equivalent.)
        deferred_scales = [scale_first, scale_second_t.t()]
        deferred_bias = bias
        scale_a = scale_first.new_ones((1, 1))
        scale_b = scale_second_t.new_ones((1, 1))
        mm_bias = None

    res = torch._scaled_mm(
        first_fp8,
        second_t_fp8.t(),
        scale_a=scale_a,
        scale_b=scale_b,
        bias=mm_bias,
        out_dtype=torch.bfloat16,
        use_fast_accum=use_fast_accum,
    )

    # No-op when deferred_scales is empty (fast-accum path).
    res = apply_to_partial(mul_tiled, res, *deferred_scales).to(torch.bfloat16)
    if deferred_bias is not None:
        res += deferred_bias

    return res
|
|
|
|
@torch.compiler.allow_in_graph
class Fp8LinearFn(torch.autograd.Function):
    """Autograd wrapper for the row-wise-scaled fp8 matmul.

    Forward computes ``a @ b_t.t() + bias`` through ``matmul`` (fp8 with fast
    accumulation).  Backward recomputes grad_a in fp8 (without fast accum) and
    grad_b/grad_bias as plain bf16/regular matmul reductions.
    """

    @staticmethod
    def forward(ctx, a, b_t, bias):
        # Per-row amax, with each row additionally split into get_splitk()
        # chunks along the K dim so split-K-sharded DTensors keep one amax
        # per local shard.
        amax_a = a.abs().unflatten(-1, (get_splitk(a), -1)).amax(dim=-1)
        amax_b_t = b_t.abs().unflatten(-1, (get_splitk(b_t), -1)).amax(dim=-1)
        out = matmul(a, amax_a, b_t, amax_b_t, bias, use_fast_accum=True)

        # Record requires_grad flags now; they are not available on saved
        # tensors inside backward.
        ctx.a_requires_grad = a.requires_grad
        ctx.b_requires_grad = b_t.requires_grad
        ctx.bias_requires_grad = bias.requires_grad if bias is not None else False

        # amax_b_t is saved so backward can derive b's scales without
        # recomputing abs().amax over the full weight.
        ctx.save_for_backward(a, b_t, amax_b_t)

        return out

    @staticmethod
    def backward(ctx, grad_out):
        a, b_t, amax_b_t = ctx.saved_tensors

        # Adds exact zeros (grad_out[:1, :, None].squeeze(0) has shape (N, 1),
        # broadcast against b_t's (N, K) and multiplied by 0), so values are
        # unchanged — but it makes b_t data-depend on grad_out in the graph.
        # NOTE(review): presumably this forces an execution/communication
        # ordering under torch.compile / async-TP — confirm before removing.
        b_t = b_t + grad_out[:1, :, None].squeeze(0) * 0

        if ctx.a_requires_grad:
            # grad_a = grad_out @ b, as an fp8 matmul without fast accum.
            b = b_t.t().contiguous()
            amax_grad_out = (
                grad_out.abs().unflatten(-1, (get_splitk(grad_out), -1)).amax(dim=-1)
            )
            # Derive a conservative per-group amax for b from the saved
            # amax_b_t (max of maxes bounds every member), then expand it
            # row-wise to match b's leading dim instead of recomputing from b.
            amax_b = amax_b_t.t().unflatten(-1, (get_splitk(b), -1)).amax(dim=-1)
            amax_b = amax_b.repeat_interleave(
                b.shape[0] // amax_b.shape[0], dim=0, output_size=b.shape[0]
            )
            grad_a = matmul(grad_out, amax_grad_out, b, amax_b, None, use_fast_accum=False)
        else:
            grad_a = None
        if ctx.b_requires_grad:
            # Weight gradient stays in the incoming (non-fp8) precision.
            grad_b = grad_out.t() @ a
        else:
            grad_b = None
        if ctx.bias_requires_grad:
            grad_bias = grad_out.sum(dim=0)
        else:
            grad_bias = None

        return grad_a, grad_b, grad_bias
|
|
|
|
class Fp8Linear(torch.nn.Linear):
    """Drop-in ``nn.Linear`` whose matmul runs in row-wise-scaled float8.

    Same parameters and interface as ``torch.nn.Linear``; only ``forward``
    is overridden to route through ``Fp8LinearFn``.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Collapse all leading batch dims so the fp8 kernel sees a 2-D matmul,
        # then restore them on the output.
        batch_shape = input.shape[:-1]
        flat_out = Fp8LinearFn.apply(input.flatten(end_dim=-2), self.weight, self.bias)
        return flat_out.unflatten(0, batch_shape)
|
|
|
|
def named_replace(fn: Callable[[torch.nn.Module, str], torch.nn.Module], module: torch.nn.Module, name="") -> torch.nn.Module:
    """Depth-first, post-order module rewrite.

    Every submodule of ``module`` (children first, then ``module`` itself) is
    passed to ``fn`` along with its dotted qualified name (the root gets
    ``name``, "" by default); whatever ``fn`` returns replaces the original
    in its parent.  Children are snapshotted with ``list(...)`` so ``fn`` may
    swap them without breaking iteration.
    """
    for child_name, child in list(module.named_children()):
        qualified = f"{name}.{child_name}" if name else child_name
        setattr(module, child_name, named_replace(fn, child, qualified))
    return fn(module, name)
|
|
|
|
def convert_linears_to_fp8(root_module: torch.nn.Module, recipe: str, filter: str) -> torch.nn.Module:
    """Swap matching ``nn.Linear`` submodules of ``root_module`` for ``Fp8Linear``.

    Args:
        root_module: module tree to rewrite in place (also returned).
        recipe: float8 scaling recipe; only ``"rowwise"`` is supported.
        filter: regex searched against each module's dotted qualified name;
            only matching Linears are converted.

    Returns:
        The (mutated) ``root_module``.

    Raises:
        RuntimeError: if ``recipe`` is not a known recipe.
    """
    if recipe not in ["rowwise"]:
        raise RuntimeError(f"Unknown float8 recipe {recipe!r}")

    if recipe == "rowwise" and torch.__version__ < "2.5":
        warnings.warn("Float8 row-wise scaling is slow in PyTorch prior to v2.5.0")

    # NOTE(review): presumably needed so inductor picks good kernels for the
    # fp8 path under torch.compile — confirm why multi_kernel is required.
    torch._inductor.config.triton.multi_kernel = 1

    name_pattern = re.compile(filter)

    def replace(module: torch.nn.Module, name: str) -> torch.nn.Module:
        # Anything that isn't a matching Linear passes through untouched.
        if not isinstance(module, torch.nn.Linear) or not name_pattern.search(name):
            return module
        # Only exact nn.Linear is supported; a subclass would silently lose
        # its overridden behavior if converted.
        assert type(module) == torch.nn.Linear, str(type(module))
        assert recipe == "rowwise", recipe
        new_module = Fp8Linear(
            in_features=module.in_features,
            out_features=module.out_features,
            bias=module.bias is not None,
            dtype=module.weight.dtype,
            device=module.weight.device,
        )
        # Share (not copy) the original parameters so optimizer state and
        # in-flight references stay valid.
        new_module.weight = module.weight
        new_module.bias = module.bias
        return new_module

    out = named_replace(replace, root_module)

    # Swapping modules invalidates previously compiled code and CUDA graphs.
    torch._dynamo.reset_code_caches()
    from torch._inductor.cudagraph_trees import reset_cudagraph_trees
    reset_cudagraph_trees()

    return out
|
|
|
|
| |
| |
| |
# ---------------------------------------------------------------------------
# DTensor sharding-propagation shims for older PyTorch.
#
# Before the 2.7 nightly of 2025-01-07, DTensor had no sharding strategies for
# the ops this file relies on: the amax/amin reductions (used for per-row
# maxima) and torch.ops.aten._scaled_mm (the fp8 matmul).  Register them here
# so DTensor inputs work on those versions; newer PyTorch ships equivalents.
# NOTE(review): this reaches into private torch.distributed.tensor._* APIs,
# which can change between even minor releases — the version gate above is
# what keeps it safe.
if torch.__version__ < "2.7.0.dev20250107":
    from torch.distributed.device_mesh import DeviceMesh
    from torch.distributed.tensor._dtensor_spec import DTensorSpec
    from torch.distributed.tensor._op_schema import (
        OpSchema,
        OpStrategy,
        PlacementStrategy,
        RuntimeSchemaInfo,
    )
    from torch.distributed.tensor._ops._einsum_strategy import gen_einsum_strategies
    from torch.distributed.tensor._ops._math_ops import (
        _infer_reduction_dims,
        common_reduction_strategy,
    )
    from torch.distributed.tensor._ops.utils import (
        generate_redistribute_costs,
        is_tensor_shardable,
        prod,
        register_op_strategy,
    )
    from torch.distributed.tensor.placement_types import Replicate

    # max/min commute with sharding (the max of per-shard maxes is the global
    # max), so both reduce "linearly" across shards.
    LINEAR_REDUCTION_OP_MAP = {
        torch.ops.aten.amax.default: "max",
        torch.ops.aten.amin.default: "min",
    }

    @register_op_strategy(
        list(LINEAR_REDUCTION_OP_MAP.keys()), schema_info=RuntimeSchemaInfo(1)
    )
    def linear_reduction_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
        """Sharding strategy for amax/amin: delegate to the common linear
        reduction handling after decoding dims/keepdim from the op schema."""
        args_schema = op_schema.args_schema
        input_strategy = args_schema[0]
        assert isinstance(input_strategy, OpStrategy)
        dims = None
        if len(op_schema.args_schema) > 1:
            dims = _infer_reduction_dims(args_schema[1], input_strategy.ndim)

        # No explicit dims argument means a full reduction over every dim.
        reduce_dims = list(range(input_strategy.ndim)) if dims is None else dims

        # Third positional arg, when present, is keepdim.
        keep_dim = len(op_schema.args_schema) > 2 and bool(op_schema.args_schema[2])
        reduction_op = LINEAR_REDUCTION_OP_MAP[op_schema.op]
        return common_reduction_strategy(
            mesh,
            input_strategy,
            reduce_dims,
            keep_dim=keep_dim,
            reduction_linear=True,
            reduction_op=reduction_op,
        )

    def _mm_like_strategy(
        mm_equation: str, mesh: DeviceMesh, op_schema: OpSchema
    ) -> OpStrategy:
        """Build an OpStrategy for _scaled_mm from einsum mm strategies.

        Starts from the generic matmul shardings for ``mm_equation``, then
        attaches specs for the two scale operands: a numel-1 scale (tensor-wise
        scaling) is replicated, while a larger scale (row-wise scaling) must
        follow its operand's placement.  Candidate shardings where any operand
        is not actually shardable are dropped; the rest get redistribute costs.
        """
        (
            self_strategy,
            mat2_strategy,
            scale_self_strategy,
            scale_mat2_strategy,
            bias_strategy,
            scale_result_strategy,
            *_,
        ) = op_schema.args_schema
        assert isinstance(self_strategy, OpStrategy)
        assert isinstance(mat2_strategy, OpStrategy)
        assert isinstance(scale_self_strategy, OpStrategy)
        assert isinstance(scale_mat2_strategy, OpStrategy)
        # bias / result-scale propagation is not implemented by this shim.
        assert bias_strategy is None
        assert scale_result_strategy is None

        mm_strategy = gen_einsum_strategies(mm_equation, mesh)
        assert isinstance(mm_strategy, OpStrategy)

        strategies = mm_strategy.strategies
        filtered_strategies = []
        for strtg in strategies:
            assert isinstance(strtg, PlacementStrategy)
            assert strtg.input_specs is not None
            self_spec = strtg.input_specs[0]
            mat2_spec = strtg.input_specs[1]
            assert isinstance(self_spec, DTensorSpec)
            assert isinstance(mat2_spec, DTensorSpec)
            scale_self_spec = (
                DTensorSpec(self_spec.mesh, (Replicate(),))
                if prod(scale_self_strategy.shape) == 1
                else self_spec
            )
            scale_mat2_spec = (
                DTensorSpec(mat2_spec.mesh, (Replicate(),))
                if prod(scale_mat2_strategy.shape) == 1
                else mat2_spec
            )
            # Mutates the candidate in place to append the scale operand specs.
            strtg.input_specs.extend([scale_self_spec, scale_mat2_spec])
            if (
                is_tensor_shardable(self_strategy.shape, self_spec)
                and is_tensor_shardable(mat2_strategy.shape, mat2_spec)
                and is_tensor_shardable(scale_self_strategy.shape, scale_self_spec)
                and is_tensor_shardable(scale_mat2_strategy.shape, scale_mat2_spec)
            ):
                redistribute_cost = [
                    generate_redistribute_costs(self_strategy, self_spec),
                    generate_redistribute_costs(mat2_strategy, mat2_spec),
                    generate_redistribute_costs(scale_self_strategy, scale_self_spec),
                    generate_redistribute_costs(scale_mat2_strategy, scale_mat2_spec),
                ]
                strtg.redistribute_cost = redistribute_cost
                filtered_strategies.append(strtg)

        mm_strategy.strategies = filtered_strategies

        return mm_strategy

    @register_op_strategy(torch.ops.aten._scaled_mm.default)
    def mm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
        """Register _scaled_mm as a plain ``mk,kn->mn`` matmul for sharding."""
        return _mm_like_strategy("mk,kn->mn", mesh, op_schema)
|