File size: 12,298 Bytes
72c0672
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
# Copyright (c) Meta Platforms, Inc. and affiliates.

import re
import warnings
from typing import Callable

import torch
from torch.distributed.tensor import DTensor, Partial, Shard

# avoid division by zero when calculating scale
EPS = 1e-12


def get_splitk(t):
    # When tensor parallelism splits the operands along the reduction dim, it's
    # more natural (and efficient, and accurate) to do sub-row-wise scaling, so
    # that each rank can compute its own scales independently.
    if isinstance(t, DTensor) and t.placements == (Shard(dim=1),):
        return t.device_mesh.size()
    else:
        return 1


def mul_tiled(a, *bs):
    """Multiply *a* tile-wise by each factor in *bs*.

    For an m x n factor, *a* is viewed as an m x n grid of equally-sized tiles
    and every tile is multiplied by the corresponding scalar entry of the
    factor. Factors are applied one after another, in order.
    """
    out = a
    for factor in bs:
        rows, cols = factor.shape[0], factor.shape[-1]
        # View dim 0 as (rows, tile_h) and the last dim as (cols, tile_w),
        # broadcast the factor over the tile dims, then restore the shape.
        tiled = out.unflatten(0, (rows, -1)).unflatten(-1, (cols, -1))
        tiled = tiled * factor[:, None, :, None]
        out = tiled.flatten(end_dim=1).flatten(start_dim=-2)
    return out


def apply_to_partial(fn, t, *args, **kwargs):
    # With tensor parallelism, _scaled_mm returns a "partial" result, but we do
    # manual (post-)scaling which we want to apply to each partial term
    # separately, thus we do this hack to "unpack" the DTensors.
    if isinstance(t, DTensor) and t.placements == (Partial(),):
        return torch.distributed.tensor.experimental.local_map(fn, [*t.placements])(t, *args, **kwargs)
    else:
        return fn(t, *args, **kwargs)


def scale(t, amax_t):
    """Quantize *t* to float8_e4m3fn using tile-wise scales derived from *amax_t*.

    Each entry of *amax_t* is the absolute maximum of the corresponding tile
    of *t*; the scale maps that maximum onto the fp8 dynamic range. Returns
    the fp8 tensor together with the (float32) scales needed to dequantize.
    """
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # Clamping by EPS keeps all-zero tiles from producing a zero scale
    # (which would blow up in the reciprocal below).
    scale_t = torch.clamp(amax_t.float(), min=EPS) / fp8_max
    t_fp8 = mul_tiled(t, scale_t.reciprocal()).to(torch.float8_e4m3fn)
    return t_fp8, scale_t


def matmul(first, amax_first, second_t, amax_second_t, bias, use_fast_accum):
    # Compute `first @ second_t.t() (+ bias)` in float8 with row-wise scaling,
    # returning a bf16 result. `second_t` is the second operand pre-transposed;
    # `amax_*` are the per-(sub-)row absolute maxima used to derive the scales.
    first_fp8, scale_first = scale(first, amax_first)
    second_t_fp8, scale_second_t = scale(second_t, amax_second_t)

    # PyTorch's row-wise scaled matmul kernel is based on CUTLASS and is quite
    # slow when fast_accum is disabled. Hence we fall back to an "unscaled"
    # matmul, which uses cuBLAS, and apply the scale manually afterwards.
    post_scales = []
    post_bias = None
    if not use_fast_accum:
        # Stash the real scales (and bias) for post-application, and feed
        # unit scales into _scaled_mm so it behaves as a plain fp8 matmul.
        post_scales = [scale_first, scale_second_t.t()]
        scale_first = scale_first.new_ones((1, 1))
        scale_second_t = scale_second_t.t().new_ones((1, 1))
        post_bias, bias = bias, None

    res = torch._scaled_mm(
        first_fp8,
        second_t_fp8.t(),
        scale_a=scale_first,
        scale_b=scale_second_t.t(),
        bias=bias,
        out_dtype=torch.bfloat16,
        use_fast_accum=use_fast_accum,
    )

    # Under tensor parallelism `res` may be a Partial DTensor; the manual
    # post-scaling must then be applied to each partial term separately.
    res = apply_to_partial(mul_tiled, res, *post_scales).to(torch.bfloat16)
    if post_bias is not None:
        res += post_bias

    return res


@torch.compiler.allow_in_graph
class Fp8LinearFn(torch.autograd.Function):
    """Autograd function for a linear layer computed in float8.

    Forward runs `a @ b_t.t() + bias` via the fp8 `matmul` helper with fast
    accumulation; backward computes grad_a in fp8 (without fast accumulation)
    and grad_b / grad_bias with plain high-precision matmuls.
    """

    @staticmethod
    def forward(ctx, a, b_t, bias):
        # Per-(sub-)row absolute maxima; get_splitk > 1 splits each row so
        # tensor-parallel ranks can compute their scales independently.
        amax_a = a.abs().unflatten(-1, (get_splitk(a), -1)).amax(dim=-1)
        amax_b_t = b_t.abs().unflatten(-1, (get_splitk(b_t), -1)).amax(dim=-1)
        out = matmul(a, amax_a, b_t, amax_b_t, bias, use_fast_accum=True)

        ctx.a_requires_grad = a.requires_grad
        ctx.b_requires_grad = b_t.requires_grad
        ctx.bias_requires_grad = bias.requires_grad if bias is not None else False

        # amax_b_t is saved so backward can reuse it to derive the amax of
        # the transposed weight instead of recomputing it from scratch.
        ctx.save_for_backward(a, b_t, amax_b_t)

        return out

    @staticmethod
    def backward(ctx, grad_out):
        a, b_t, amax_b_t = ctx.saved_tensors

        # Workaround for https://github.com/pytorch/pytorch/issues/141881.
        # The partitioner would pre-compute the transposed scaling of the weight
        # in the forward (as it's most efficient, but it actually uses too much
        # memory). We prevent that by making the scaling depend on the gradient
        # in a way that has no effect and will be optimized away later.
        # Care is needed to support tensor parallelism and circumvent bugs.
        b_t = b_t + grad_out[:1, :, None].squeeze(0) * 0

        if ctx.a_requires_grad:
            b = b_t.t().contiguous()
            amax_grad_out = (
                grad_out.abs().unflatten(-1, (get_splitk(grad_out), -1)).amax(dim=-1)
            )
            # Derive the amax of the transposed weight from the saved amax:
            # max over the old split groups, then repeat per-row so the shape
            # matches what `matmul`'s tiling expects.
            amax_b = amax_b_t.t().unflatten(-1, (get_splitk(b), -1)).amax(dim=-1)
            amax_b = amax_b.repeat_interleave(
                b.shape[0] // amax_b.shape[0], dim=0, output_size=b.shape[0]
            )
            grad_a = matmul(grad_out, amax_grad_out, b, amax_b, None, use_fast_accum=False)
        else:
            grad_a = None
        if ctx.b_requires_grad:
            # Weight gradient stays in high precision (plain bf16/fp32 matmul).
            grad_b = grad_out.t() @ a
        else:
            grad_b = None
        if ctx.bias_requires_grad:
            grad_bias = grad_out.sum(dim=0)
        else:
            grad_bias = None

        return grad_a, grad_b, grad_bias


class Fp8Linear(torch.nn.Linear):
    """Drop-in nn.Linear replacement whose matmul runs in float8."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Collapse all leading (batch) dims into one, run the fp8 linear,
        # then restore the original leading shape on the output.
        leading_shape = input.shape[:-1]
        flat_input = input.flatten(end_dim=-2)
        flat_out = Fp8LinearFn.apply(flat_input, self.weight, self.bias)
        return flat_out.unflatten(0, leading_shape)


def named_replace(fn: Callable[[torch.nn.Module, str], torch.nn.Module], module: torch.nn.Module, name="") -> torch.nn.Module:
    """Post-order traversal: apply *fn* to every submodule (children first),
    re-attaching each returned module, then apply *fn* to *module* itself.

    *fn* receives the module and its dot-separated qualified name (the root
    gets the empty string). Returns whatever *fn* returns for the root.
    """
    for child_name, child in list(module.named_children()):
        qualified = f"{name}.{child_name}" if name else child_name
        setattr(module, child_name, named_replace(fn, child, qualified))
    return fn(module, name)


def convert_linears_to_fp8(root_module: torch.nn.Module, recipe: str, filter: str) -> torch.nn.Module:
    """Replace, in place, every plain nn.Linear under *root_module* whose
    qualified name matches the *filter* regex with an Fp8Linear (sharing the
    original weight and bias), and return the root module.

    Only ``recipe="rowwise"`` is supported; anything else raises RuntimeError.
    Also tweaks Inductor config and resets compile caches as side effects.
    """
    if recipe not in ["rowwise"]:
        raise RuntimeError(f"Unknown float8 recipe {recipe!r}")

    if recipe == "rowwise" and torch.__version__ < "2.5":
        # We need https://github.com/pytorch/pytorch/pull/134781.
        warnings.warn("Float8 row-wise scaling is slow in PyTorch prior to v2.5.0")

    # Multi-kernel makes Inductor auto-tune between a regular "streaming"-based
    # reduction kernel and a "persistent" reduction kernel. Since fp8 has some
    # multi-pass steps (e.g., first get amax, then scale), persistent kernels
    # should perform better.
    torch._inductor.config.triton.multi_kernel = 1

    filter_re = re.compile(filter)

    def replace(module: torch.nn.Module, name: str) -> torch.nn.Module:
        if not isinstance(module, torch.nn.Linear) or not filter_re.search(name):
            return module
        # Exact-type check: Linear subclasses (including Fp8Linear itself)
        # must not be converted.
        assert type(module) == torch.nn.Linear, str(type(module))
        assert recipe == "rowwise", recipe
        new_module = Fp8Linear(
            in_features=module.in_features,
            out_features=module.out_features,
            bias=module.bias is not None,
            dtype=module.weight.dtype,
            device=module.weight.device,
        )
        # Share (don't copy) the original parameters.
        new_module.weight = module.weight
        new_module.bias = module.bias
        return new_module

    out = named_replace(replace, root_module)

    # Force re-compile everything
    torch._dynamo.reset_code_caches()
    from torch._inductor.cudagraph_trees import reset_cudagraph_trees
    reset_cudagraph_trees()

    return out


# We need some upstream PyTorch fixes which are only present in v2.7+ or in
# nightlies starting from January 7, 2025. For earlier versions, we copy-pasted
# the relevant pieces of code below.
if torch.__version__ < "2.7.0.dev20250107":
    from torch.distributed.device_mesh import DeviceMesh
    from torch.distributed.tensor._dtensor_spec import DTensorSpec
    from torch.distributed.tensor._op_schema import (
        OpSchema,
        OpStrategy,
        PlacementStrategy,
        RuntimeSchemaInfo,
    )
    from torch.distributed.tensor._ops._einsum_strategy import gen_einsum_strategies
    from torch.distributed.tensor._ops._math_ops import (
        _infer_reduction_dims,
        common_reduction_strategy,
    )
    from torch.distributed.tensor._ops.utils import (
        generate_redistribute_costs,
        is_tensor_shardable,
        prod,
        register_op_strategy,
    )
    from torch.distributed.tensor.placement_types import Replicate

    # Cherry-pick of https://github.com/pytorch/pytorch/pull/143747

    # Maps reduction ops to the reduction name used by DTensor's strategy
    # machinery; both are "linear" reductions (partial results can be combined
    # by re-applying the same op).
    LINEAR_REDUCTION_OP_MAP = {
        torch.ops.aten.amax.default: "max",
        torch.ops.aten.amin.default: "min",
    }

    @register_op_strategy(
        list(LINEAR_REDUCTION_OP_MAP.keys()), schema_info=RuntimeSchemaInfo(1)
    )
    def linear_reduction_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
        # Sharding strategy for amax/amin: defer to the common reduction
        # strategy with reduction_linear=True so sharded dims can be reduced
        # locally and combined across ranks.
        args_schema = op_schema.args_schema
        input_strategy = args_schema[0]
        assert isinstance(input_strategy, OpStrategy)
        dims = None
        if len(op_schema.args_schema) > 1:
            dims = _infer_reduction_dims(args_schema[1], input_strategy.ndim)

        # No explicit dims means a full reduction over every dimension.
        reduce_dims = list(range(input_strategy.ndim)) if dims is None else dims

        # Optional third arg is keepdim.
        keep_dim = len(op_schema.args_schema) > 2 and bool(op_schema.args_schema[2])
        reduction_op = LINEAR_REDUCTION_OP_MAP[op_schema.op]
        return common_reduction_strategy(
            mesh,
            input_strategy,
            reduce_dims,
            keep_dim=keep_dim,
            reduction_linear=True,
            reduction_op=reduction_op,
        )

    # Cherry-pick of https://github.com/pytorch/pytorch/pull/143760

    def _mm_like_strategy(
        mm_equation: str, mesh: DeviceMesh, op_schema: OpSchema
    ) -> OpStrategy:
        # Build the sharding strategy for _scaled_mm by enumerating the einsum
        # strategies for the underlying matmul, attaching specs for the two
        # scale operands, and filtering to shardable placements with costs.
        (
            self_strategy,
            mat2_strategy,
            scale_self_strategy,
            scale_mat2_strategy,
            bias_strategy,
            scale_result_strategy,
            *_,
        ) = op_schema.args_schema
        assert isinstance(self_strategy, OpStrategy)
        assert isinstance(mat2_strategy, OpStrategy)
        assert isinstance(scale_self_strategy, OpStrategy)
        assert isinstance(scale_mat2_strategy, OpStrategy)
        # Bias and result-scale arguments are not supported by this strategy.
        assert bias_strategy is None
        assert scale_result_strategy is None
        # generate all possible strategies for mm
        mm_strategy = gen_einsum_strategies(mm_equation, mesh)
        assert isinstance(mm_strategy, OpStrategy)
        # filter out invalid strategies and associate costs
        strategies = mm_strategy.strategies
        filtered_strategies = []
        for strtg in strategies:
            assert isinstance(strtg, PlacementStrategy)
            assert strtg.input_specs is not None
            self_spec = strtg.input_specs[0]
            mat2_spec = strtg.input_specs[1]
            assert isinstance(self_spec, DTensorSpec)
            assert isinstance(mat2_spec, DTensorSpec)
            # A single-element scale (tensor-wise scaling) is replicated;
            # otherwise (row-wise scaling) the scale follows its operand's
            # placement.
            scale_self_spec = (
                DTensorSpec(self_spec.mesh, (Replicate(),))
                if prod(scale_self_strategy.shape) == 1
                else self_spec
            )
            scale_mat2_spec = (
                DTensorSpec(mat2_spec.mesh, (Replicate(),))
                if prod(scale_mat2_strategy.shape) == 1
                else mat2_spec
            )
            strtg.input_specs.extend([scale_self_spec, scale_mat2_spec])
            if (
                is_tensor_shardable(self_strategy.shape, self_spec)
                and is_tensor_shardable(mat2_strategy.shape, mat2_spec)
                and is_tensor_shardable(scale_self_strategy.shape, scale_self_spec)
                and is_tensor_shardable(scale_mat2_strategy.shape, scale_mat2_spec)
            ):
                redistribute_cost = [
                    generate_redistribute_costs(self_strategy, self_spec),
                    generate_redistribute_costs(mat2_strategy, mat2_spec),
                    generate_redistribute_costs(scale_self_strategy, scale_self_spec),
                    generate_redistribute_costs(scale_mat2_strategy, scale_mat2_spec),
                ]
                strtg.redistribute_cost = redistribute_cost
                filtered_strategies.append(strtg)

        mm_strategy.strategies = filtered_strategies

        return mm_strategy

    @register_op_strategy(torch.ops.aten._scaled_mm.default)
    def mm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
        # _scaled_mm computes self @ mat2, i.e. an "mk,kn->mn" contraction.
        return _mm_like_strategy("mk,kn->mn", mesh, op_schema)