
Commit 092ca75

Commit message: address comments
1 parent 615de5e commit 092ca75

2 files changed: +151 -92 lines


test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 6 additions & 0 deletions
@@ -184,6 +184,12 @@ def test_fp8_matmul_variants(
         )
         if isinstance(granularity, PerRow) and dtype != torch.bfloat16:
             return unittest.skip("per row only works with bfloat16")
+        if kernel_preference == KernelPreference.FBGEMM and (
+            (not _is_fbgemm_gpu_genai_available()) or (not is_sm_at_least_90())
+        ):
+            return unittest.skip(
+                "Requires fbgemm_gpu_genai to run fbgemm kernel preference test"
+            )
 
         M, N, K = sizes
         input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
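For context, the new skip mirrors what torchao's `_is_fbgemm_gpu_genai_available()` and `is_sm_at_least_90()` helpers check. Below is a minimal standalone sketch of an equivalent capability probe; the `fbgemm_gpu` import name and the SM90 threshold are assumptions for illustration, not taken from this diff:

import importlib.util

import torch


def can_run_fbgemm_f8f8bf16() -> bool:
    # Assumption: the genai kernels ship with the `fbgemm_gpu` package and
    # need a Hopper-class (SM90+) GPU, matching the skip condition above.
    has_fbgemm = importlib.util.find_spec("fbgemm_gpu") is not None
    if not (has_fbgemm and torch.cuda.is_available()):
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 9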

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 145 additions & 92 deletions
@@ -259,11 +259,18 @@ def _(func, types, args, kwargs):
     return _float8_linear_impl(input_tensor, weight_tensor, bias)
 
 
-@implements(aten.mm.default)
+@implements(aten.matmul.default)
 @implements_torch_function(torch.matmul)
 def _(func, types, args, kwargs):
     input_tensor, weight_tensor = args[0], args[1]
-    return _float8_mm_impl(input_tensor, weight_tensor)
+    return _float8_matmul_impl(input_tensor, weight_tensor)
+
+
+@implements(aten.mm.default)
+@implements_torch_function(torch.mm)
+def _(func, types, args, kwargs):
+    input_tensor, weight_tensor = args[0], args[1]
+    return _float8_matmul_impl(input_tensor, weight_tensor)
 
 
 @implements(aten.addmm_.default)
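The `@implements` / `@implements_torch_function` decorators register per-op handlers that the Float8Tensor subclass looks up when an op hits its dispatch hooks; after this hunk, `aten.matmul`, `aten.mm`, `torch.matmul`, and `torch.mm` all route to `_float8_matmul_impl`. A simplified standalone sketch of that registry pattern (illustrative only; torchao's real helpers live on its tensor base class, and string keys stand in for aten ops here):

from typing import Callable, Dict

_OP_IMPLS: Dict[str, Callable] = {}


def implements(op: str):
    # Register the decorated function as the handler for `op`.
    def decorator(fn: Callable) -> Callable:
        _OP_IMPLS[op] = fn
        return fn

    return decorator


@implements("aten.matmul.default")
@implements("aten.mm.default")
def _matmul_handler(func, types, args, kwargs):
    # Several ops can share one handler, as in the diff above.
    return f"handled {func}"


print(_OP_IMPLS["aten.mm.default"] is _OP_IMPLS["aten.matmul.default"])  # True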
@@ -275,24 +282,110 @@ def _(func, types, args, kwargs):
     )
     assert kwargs.get("alpha", 1) == 1, "only alpha=1 is supported"
     assert kwargs.get("beta", 1) == 1, "only beta=1 is supported"
-    out = _float8_mm_impl(input_tensor, weight_tensor)
+    out = _float8_matmul_impl(input_tensor, weight_tensor)
     return bias_tensor.add_(out)
 
 
-def _float8_mm_impl(
-    input_tensor: torch.Tensor,
-    weight_tensor: torch.Tensor,
+def _get_matmul_kernel_choice(weight_tensor: Float8Tensor) -> str:
+    """
+    Return the kernel choice for matmuls, either "fbgemm" or "torch".
+    """
+    if weight_tensor.kernel_preference == KernelPreference.AUTO:
+        kernel_choice = "torch"
+        if _is_fbgemm_gpu_genai_available() and is_sm_at_least_90():
+            kernel_choice = "fbgemm"
+    elif weight_tensor.kernel_preference == KernelPreference.FBGEMM:
+        kernel_choice = "fbgemm"
+    else:
+        assert weight_tensor.kernel_preference == KernelPreference.TORCH, (
+            f"{weight_tensor.kernel_preference=} not handled"
+        )
+        kernel_choice = "torch"
+    return kernel_choice
+
+
+def _call_fbgemm_f8f8bf16_matmul(
+    input_tensor: Float8Tensor,
+    weight_tensor: Float8Tensor,
+    bias: Optional[torch.Tensor] = None,
+    weight_is_already_transposed: bool = False,
 ) -> torch.Tensor:
-    assert isinstance(weight_tensor, Float8Tensor), (
-        f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}"
+    """
+    Call `torch.ops.fbgemm.f8f8bf16*` ops.
+
+    These ops expect the weight tensor to be transposed.
+    If `weight_is_already_transposed=True` (e.g. in the linear case),
+    then we avoid unnecessarily double transposing the weight.
+    """
+    assert _is_fbgemm_gpu_genai_available(), (
+        "Expected fbgemm_gpu_genai package to be installed"
     )
-    is_transposed = weight_tensor.qdata.stride(-2) < weight_tensor.qdata.stride(-1)
-    # For matmul(x, w.t()), just call the linear implementation
-    # For matmul(x, w), just dequantize for now, we can optimize later
-    if is_transposed:
-        return _float8_linear_impl(input_tensor, weight_tensor.t())
+    assert is_sm_at_least_90(), "Expected SM90+ for fbgemm_gpu_genai"
+    mm_config = weight_tensor.mm_config
+    assert mm_config is not None
+
+    if not weight_is_already_transposed:
+        weight_tensor = weight_tensor.t()
+
+    out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)
+    xq = input_tensor.qdata.reshape(-1, input_tensor.qdata.shape[-1])
+    wq = weight_tensor.qdata
+    x_scale = input_tensor.scale
+    w_scale = weight_tensor.scale
+    if _is_rowwise_scaled(weight_tensor):
+        assert _is_rowwise_scaled(input_tensor), (
+            "Input tensor must be rowwise block size"
+        )
+        res = torch.ops.fbgemm.f8f8bf16_rowwise(
+            xq,
+            wq,
+            x_scale,
+            w_scale,
+            bias=bias,
+            use_fast_accum=mm_config.use_fast_accum,
+        ).reshape(out_shape)
     else:
-        return torch.matmul(input_tensor, weight_tensor.dequantize())
+        assert _is_tensorwise_scaled(weight_tensor)
+        assert _is_tensorwise_scaled(input_tensor)
+        res = torch.ops.fbgemm.f8f8bf16(
+            xq,
+            wq,
+            x_scale * w_scale,
+            use_fast_accum=mm_config.use_fast_accum,
+        ).reshape(out_shape)
+        if bias is not None:
+            res = res + bias
+    return res
+
+
+def _call_float8_scaled_mm(
+    input_tensor: Float8Tensor,
+    weight_tensor: Float8Tensor,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    scaled_mm_config = weight_tensor.mm_config
+    assert scaled_mm_config is not None
+    assert weight_tensor.dim() == 2
+    out_shape = (*input_tensor.shape[:-1], weight_tensor.shape[1])
+
+    # Extract tensor data and scales
+    inpt_data = input_tensor.qdata.reshape(-1, input_tensor.qdata.shape[-1])
+    w_data = weight_tensor.qdata
+    input_scale = input_tensor.scale
+    w_scale = weight_tensor.scale
+
+    input_scale = preprocess_scale(input_scale, input_tensor.shape)
+    inpt_data, w_data = preprocess_data(inpt_data, w_data, scaled_mm_config)
+
+    return addmm_float8_unwrapped_inference(
+        inpt_data,
+        input_scale,
+        w_data,
+        w_scale,
+        output_dtype=input_tensor.dtype,
+        bias=bias,
+        use_fast_accum=scaled_mm_config.use_fast_accum,
+    ).reshape(out_shape)
 
 
 def _float8_linear_impl(
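The new `_get_matmul_kernel_choice` centralizes the preference resolution that was previously inlined in `_float8_linear_impl`: AUTO picks fbgemm only when the package and an SM90+ GPU are available, FBGEMM always picks fbgemm (and the fbgemm helper then asserts availability), and TORCH always picks the `torch._scaled_mm` path. A standalone sketch of that decision table, using a stand-in enum and a single `fbgemm_ok` flag in place of the two availability checks:

from enum import Enum


class KernelPreference(Enum):  # simplified stand-in for torchao's enum
    AUTO = "auto"
    FBGEMM = "fbgemm"
    TORCH = "torch"


def get_matmul_kernel_choice(pref: KernelPreference, fbgemm_ok: bool) -> str:
    # `fbgemm_ok` stands in for
    # `_is_fbgemm_gpu_genai_available() and is_sm_at_least_90()`.
    if pref == KernelPreference.AUTO:
        return "fbgemm" if fbgemm_ok else "torch"
    if pref == KernelPreference.FBGEMM:
        return "fbgemm"  # availability is asserted later, inside the fbgemm helper
    assert pref == KernelPreference.TORCH
    return "torch"


for pref in KernelPreference:
    for fbgemm_ok in (True, False):
        print(pref.name, fbgemm_ok, "->", get_matmul_kernel_choice(pref, fbgemm_ok))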
@@ -312,88 +405,17 @@ def _float8_linear_impl(
     )
 
     if isinstance(input_tensor, Float8Tensor):
-        kernel_choice = None
-
-        if weight_tensor.kernel_preference == KernelPreference.AUTO:
-            kernel_choice = "torch"
-            if _is_fbgemm_gpu_genai_available() and is_sm_at_least_90():
-                kernel_choice = "fbgemm"
-        elif weight_tensor.kernel_preference == KernelPreference.FBGEMM:
-            kernel_choice = "fbgemm"
-        else:
-            assert weight_tensor.kernel_preference == KernelPreference.TORCH, (
-                f"{weight_tensor.kernel_preference=} not handled"
-            )
-            kernel_choice = "torch"
-
+        kernel_choice = _get_matmul_kernel_choice(weight_tensor)
         if kernel_choice == "fbgemm":
-            assert _is_fbgemm_gpu_genai_available(), (
-                "Expected fbgemm_gpu_genai package to be installed"
+            return _call_fbgemm_f8f8bf16_matmul(
+                input_tensor,
+                weight_tensor,
+                bias,
+                weight_is_already_transposed=True,
             )
-            assert is_sm_at_least_90(), "Expected SM90+ for fbgemm_gpu_genai"
-            mm_config = weight_tensor.mm_config
-            assert mm_config is not None
-
-            out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)
-            xq = input_tensor.qdata.reshape(-1, input_tensor.qdata.shape[-1])
-            wq = weight_tensor.qdata
-            x_scale = input_tensor.scale
-            w_scale = weight_tensor.scale
-            if _is_rowwise_scaled(weight_tensor):
-                assert _is_rowwise_scaled(input_tensor), (
-                    "Input tensor must be rowwise block size"
-                )
-                res = torch.ops.fbgemm.f8f8bf16_rowwise(
-                    xq,
-                    wq,
-                    x_scale,
-                    w_scale,
-                    bias=bias,
-                    use_fast_accum=mm_config.use_fast_accum,
-                ).reshape(out_shape)
-            else:
-                assert _is_tensorwise_scaled(weight_tensor)
-                assert _is_tensorwise_scaled(input_tensor)
-                res = torch.ops.fbgemm.f8f8bf16(
-                    xq,
-                    wq,
-                    x_scale * w_scale,
-                    use_fast_accum=mm_config.use_fast_accum,
-                ).reshape(out_shape)
-                if bias is not None:
-                    res = res + bias
-            return res
         else:
             assert kernel_choice == "torch"
-            scaled_mm_config = weight_tensor.mm_config
-            assert scaled_mm_config is not None
-            out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)
-
-            # Extract tensor data and scales
-            inpt_data = input_tensor.qdata.reshape(-1, input_tensor.qdata.shape[-1])
-            w_data = weight_tensor.qdata
-            input_scale = input_tensor.scale
-            w_scale = weight_tensor.scale
-
-            # Handle rowwise scaling
-            if _is_rowwise_scaled(weight_tensor):
-                assert _is_rowwise_scaled(input_tensor), (
-                    "Input tensor must be rowwise block size"
-                )
-                w_scale = w_scale.transpose(-1, -2)
-
-            input_scale = preprocess_scale(input_scale, input_tensor.shape)
-            inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config)
-
-            return addmm_float8_unwrapped_inference(
-                inpt_data,
-                input_scale,
-                w_data,
-                w_scale,
-                output_dtype=input_tensor.dtype,
-                bias=bias,
-                use_fast_accum=scaled_mm_config.use_fast_accum,
-            ).reshape(out_shape)
+            return _call_float8_scaled_mm(input_tensor, weight_tensor.t(), bias)
     else:
         assert not isinstance(input_tensor, TorchAOBaseTensor), (
            "Expecting input_tensor to be unquantized"
@@ -405,6 +427,37 @@ def _float8_linear_impl(
         )
 
 
+def _float8_matmul_impl(
+    input_tensor: torch.Tensor,
+    weight_tensor: torch.Tensor,
+) -> torch.Tensor:
+    assert isinstance(weight_tensor, Float8Tensor), (
+        f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}"
+    )
+
+    act_quant_kwargs = weight_tensor.act_quant_kwargs
+    # quantizing activation, if `act_quant_kwargs` is specified
+    if act_quant_kwargs is not None:
+        input_tensor = _choose_quant_func_and_quantize_tensor(
+            input_tensor, act_quant_kwargs
+        )
+
+    if isinstance(input_tensor, Float8Tensor):
+        kernel_choice = _get_matmul_kernel_choice(weight_tensor)
+        if kernel_choice == "fbgemm":
+            return _call_fbgemm_f8f8bf16_matmul(input_tensor, weight_tensor)
+        else:
+            assert kernel_choice == "torch"
+            return _call_float8_scaled_mm(input_tensor, weight_tensor)
+    else:
+        assert not isinstance(input_tensor, TorchAOBaseTensor), (
+            "Expecting input_tensor to be unquantized"
+        )
+        # when input is not `Float8Tensor`, we expect that it is not quantized
+        # so this is float8 weight only quantization
+        return torch.matmul(input_tensor, weight_tensor.dequantize())
+
+
 @implements_torch_function(torch.bmm)
 def _(func, types, args, kwargs):
     input_tensor, weight_tensor = (
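With this commit, a direct `torch.matmul` / `torch.mm` against a Float8Tensor weight goes through `_float8_matmul_impl`: the activation is quantized when `act_quant_kwargs` is set, otherwise the weight is dequantized (weight-only fallback). A minimal usage sketch, assuming torchao's usual public entry points (`quantize_`, `Float8DynamicActivationFloat8WeightConfig`, `PerRow`) and a CUDA build with float8 support; not taken from this diff and may need adjusting:

import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

# After quantize_, the linear weight is a Float8Tensor, so both the module
# forward and a direct matmul against the (transposed) weight hit the
# overrides in this file.
model = torch.nn.Sequential(torch.nn.Linear(64, 32, bias=False)).to(torch.bfloat16).cuda()
quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

x = torch.randn(8, 64, dtype=torch.bfloat16, device="cuda")
y_linear = model(x)                              # linear path
y_matmul = torch.matmul(x, model[0].weight.t())  # matmul path added in this commit
print(y_linear.shape, y_matmul.shape)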
