Commit 060b217

address comments
1 parent 615de5e commit 060b217

2 files changed: +144, -92 lines

test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 6 additions & 0 deletions
@@ -184,6 +184,12 @@ def test_fp8_matmul_variants(
         )
         if isinstance(granularity, PerRow) and dtype != torch.bfloat16:
             return unittest.skip("per row only works with bfloat16")
+        if kernel_preference == KernelPreference.FBGEMM and (
+            (not _is_fbgemm_gpu_genai_available()) or (not is_sm_at_least_90())
+        ):
+            return unittest.skip(
+                "Requires fbgemm_gpu_genai to run fbgemm kernel preference test"
+            )

         M, N, K = sizes
         input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
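
For context, below is a minimal usage sketch of the workflow this test exercises: quantize a linear layer to float8 with per-row scales, then drive both the linear override and the matmul override. The public-API names (`quantize_`, `Float8DynamicActivationFloat8WeightConfig`, `PerRow`) are assumed from torchao's float8 workflow rather than taken from this diff, and a CUDA device is required.

# Hedged sketch, not part of this commit; the imports below are assumptions
# about torchao's public float8 API.
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

m = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantize_(m, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

x = torch.randn(4, 32, 128, dtype=torch.bfloat16, device="cuda")
y_linear = m(x)                           # dispatches to _float8_linear_impl
y_matmul = torch.matmul(x, m.weight.t())  # dispatches to the new matmul override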

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 138 additions & 92 deletions
@@ -259,11 +259,11 @@ def _(func, types, args, kwargs):
     return _float8_linear_impl(input_tensor, weight_tensor, bias)


-@implements(aten.mm.default)
+@implements(aten.matmul.default)
 @implements_torch_function(torch.matmul)
 def _(func, types, args, kwargs):
     input_tensor, weight_tensor = args[0], args[1]
-    return _float8_mm_impl(input_tensor, weight_tensor)
+    return _float8_matmul_impl(input_tensor, weight_tensor)


 @implements(aten.addmm_.default)
@@ -275,24 +275,110 @@ def _(func, types, args, kwargs):
     )
     assert kwargs.get("alpha", 1) == 1, "only alpha=1 is supported"
     assert kwargs.get("beta", 1) == 1, "only beta=1 is supported"
-    out = _float8_mm_impl(input_tensor, weight_tensor)
+    out = _float8_matmul_impl(input_tensor, weight_tensor)
     return bias_tensor.add_(out)


-def _float8_mm_impl(
-    input_tensor: torch.Tensor,
-    weight_tensor: torch.Tensor,
+def _get_matmul_kernel_choice(weight_tensor: Float8Tensor) -> str:
+    """
+    Return the kernel choice for matmuls, either "fbgemm" or "torch".
+    """
+    if weight_tensor.kernel_preference == KernelPreference.AUTO:
+        kernel_choice = "torch"
+        if _is_fbgemm_gpu_genai_available() and is_sm_at_least_90():
+            kernel_choice = "fbgemm"
+    elif weight_tensor.kernel_preference == KernelPreference.FBGEMM:
+        kernel_choice = "fbgemm"
+    else:
+        assert weight_tensor.kernel_preference == KernelPreference.TORCH, (
+            f"{weight_tensor.kernel_preference=} not handled"
+        )
+        kernel_choice = "torch"
+    return kernel_choice
+
+
+def _call_fbgemm_f8f8bf16_matmul(
+    input_tensor: Float8Tensor,
+    weight_tensor: Float8Tensor,
+    bias: Optional[torch.Tensor] = None,
+    weight_is_already_transposed: bool = False,
 ) -> torch.Tensor:
-    assert isinstance(weight_tensor, Float8Tensor), (
-        f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}"
+    """
+    Call `torch.ops.fbgemm.f8f8bf16*` ops.
+
+    These ops expect the weight tensor to be transposed.
+    If `weight_is_already_transposed=True` (e.g. in the linear case),
+    then we avoid unnecessarily double transposing the weight.
+    """
+    assert _is_fbgemm_gpu_genai_available(), (
+        "Expected fbgemm_gpu_genai package to be installed"
     )
-    is_transposed = weight_tensor.qdata.stride(-2) < weight_tensor.qdata.stride(-1)
-    # For matmul(x, w.t()), just call the linear implementation
-    # For matmul(x, w), just dequantize for now, we can optimize later
-    if is_transposed:
-        return _float8_linear_impl(input_tensor, weight_tensor.t())
+    assert is_sm_at_least_90(), "Expected SM90+ for fbgemm_gpu_genai"
+    mm_config = weight_tensor.mm_config
+    assert mm_config is not None
+
+    if not weight_is_already_transposed:
+        weight_tensor = weight_tensor.t()
+
+    out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)
+    xq = input_tensor.qdata.reshape(-1, input_tensor.qdata.shape[-1])
+    wq = weight_tensor.qdata
+    x_scale = input_tensor.scale
+    w_scale = weight_tensor.scale
+    if _is_rowwise_scaled(weight_tensor):
+        assert _is_rowwise_scaled(input_tensor), (
+            "Input tensor must be rowwise block size"
+        )
+        res = torch.ops.fbgemm.f8f8bf16_rowwise(
+            xq,
+            wq,
+            x_scale,
+            w_scale,
+            bias=bias,
+            use_fast_accum=mm_config.use_fast_accum,
+        ).reshape(out_shape)
     else:
-        return torch.matmul(input_tensor, weight_tensor.dequantize())
+        assert _is_tensorwise_scaled(weight_tensor)
+        assert _is_tensorwise_scaled(input_tensor)
+        res = torch.ops.fbgemm.f8f8bf16(
+            xq,
+            wq,
+            x_scale * w_scale,
+            use_fast_accum=mm_config.use_fast_accum,
+        ).reshape(out_shape)
+    if bias is not None:
+        res = res + bias
+    return res
+
+
+def _call_float8_scaled_mm(
+    input_tensor: Float8Tensor,
+    weight_tensor: Float8Tensor,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    scaled_mm_config = weight_tensor.mm_config
+    assert scaled_mm_config is not None
+    assert weight_tensor.dim() == 2
+    out_shape = (*input_tensor.shape[:-1], weight_tensor.shape[1])
+
+    # Extract tensor data and scales
+    inpt_data = input_tensor.qdata.reshape(-1, input_tensor.qdata.shape[-1])
+    w_data = weight_tensor.qdata
+    input_scale = input_tensor.scale
+    w_scale = weight_tensor.scale
+
+    input_scale = preprocess_scale(input_scale, input_tensor.shape)
+    inpt_data, w_data = preprocess_data(inpt_data, w_data, scaled_mm_config)
+
+    return addmm_float8_unwrapped_inference(
+        inpt_data,
+        input_scale,
+        w_data,
+        w_scale,
+        output_dtype=input_tensor.dtype,
+        bias=bias,
+        use_fast_accum=scaled_mm_config.use_fast_accum,
+    ).reshape(out_shape)


 def _float8_linear_impl(
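
The helper `_get_matmul_kernel_choice` added above centralizes the kernel dispatch that was previously inlined in the linear path. Below is a standalone sketch of the same decision table with the availability checks stubbed out so it can be run anywhere; `Pref` and `choose_kernel` are illustrative stand-ins, not torchao APIs.

# Standalone sketch of the AUTO / FBGEMM / TORCH dispatch (not torchao code).
from enum import Enum

class Pref(Enum):
    AUTO = "auto"
    FBGEMM = "fbgemm"
    TORCH = "torch"

def choose_kernel(pref: Pref, fbgemm_available: bool, is_sm90_or_newer: bool) -> str:
    if pref == Pref.AUTO:
        # prefer the fbgemm kernels only when the package and an SM90+ GPU are present
        return "fbgemm" if (fbgemm_available and is_sm90_or_newer) else "torch"
    if pref == Pref.FBGEMM:
        # the real helper defers the availability asserts to the fbgemm call path
        return "fbgemm"
    return "torch"

assert choose_kernel(Pref.AUTO, True, True) == "fbgemm"
assert choose_kernel(Pref.AUTO, False, True) == "torch"
assert choose_kernel(Pref.FBGEMM, True, True) == "fbgemm"
assert choose_kernel(Pref.TORCH, True, True) == "torch"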
@@ -312,88 +398,17 @@ def _float8_linear_impl(
         )

     if isinstance(input_tensor, Float8Tensor):
-        kernel_choice = None
-
-        if weight_tensor.kernel_preference == KernelPreference.AUTO:
-            kernel_choice = "torch"
-            if _is_fbgemm_gpu_genai_available() and is_sm_at_least_90():
-                kernel_choice = "fbgemm"
-        elif weight_tensor.kernel_preference == KernelPreference.FBGEMM:
-            kernel_choice = "fbgemm"
-        else:
-            assert weight_tensor.kernel_preference == KernelPreference.TORCH, (
-                f"{weight_tensor.kernel_preference=} not handled"
-            )
-            kernel_choice = "torch"
-
+        kernel_choice = _get_matmul_kernel_choice(weight_tensor)
         if kernel_choice == "fbgemm":
-            assert _is_fbgemm_gpu_genai_available(), (
-                "Expected fbgemm_gpu_genai package to be installed"
+            return _call_fbgemm_f8f8bf16_matmul(
+                input_tensor,
+                weight_tensor,
+                bias,
+                weight_is_already_transposed=True,
             )
-            assert is_sm_at_least_90(), "Expected SM90+ for fbgemm_gpu_genai"
-            mm_config = weight_tensor.mm_config
-            assert mm_config is not None
-
-            out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)
-            xq = input_tensor.qdata.reshape(-1, input_tensor.qdata.shape[-1])
-            wq = weight_tensor.qdata
-            x_scale = input_tensor.scale
-            w_scale = weight_tensor.scale
-            if _is_rowwise_scaled(weight_tensor):
-                assert _is_rowwise_scaled(input_tensor), (
-                    "Input tensor must be rowwise block size"
-                )
-                res = torch.ops.fbgemm.f8f8bf16_rowwise(
-                    xq,
-                    wq,
-                    x_scale,
-                    w_scale,
-                    bias=bias,
-                    use_fast_accum=mm_config.use_fast_accum,
-                ).reshape(out_shape)
-            else:
-                assert _is_tensorwise_scaled(weight_tensor)
-                assert _is_tensorwise_scaled(input_tensor)
-                res = torch.ops.fbgemm.f8f8bf16(
-                    xq,
-                    wq,
-                    x_scale * w_scale,
-                    use_fast_accum=mm_config.use_fast_accum,
-                ).reshape(out_shape)
-            if bias is not None:
-                res = res + bias
-            return res
         else:
             assert kernel_choice == "torch"
-            scaled_mm_config = weight_tensor.mm_config
-            assert scaled_mm_config is not None
-            out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)
-
-            # Extract tensor data and scales
-            inpt_data = input_tensor.qdata.reshape(-1, input_tensor.qdata.shape[-1])
-            w_data = weight_tensor.qdata
-            input_scale = input_tensor.scale
-            w_scale = weight_tensor.scale
-
-            # Handle rowwise scaling
-            if _is_rowwise_scaled(weight_tensor):
-                assert _is_rowwise_scaled(input_tensor), (
-                    "Input tensor must be rowwise block size"
-                )
-                w_scale = w_scale.transpose(-1, -2)
-
-            input_scale = preprocess_scale(input_scale, input_tensor.shape)
-            inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config)
-
-            return addmm_float8_unwrapped_inference(
-                inpt_data,
-                input_scale,
-                w_data,
-                w_scale,
-                output_dtype=input_tensor.dtype,
-                bias=bias,
-                use_fast_accum=scaled_mm_config.use_fast_accum,
-            ).reshape(out_shape)
+            return _call_float8_scaled_mm(input_tensor, weight_tensor.t(), bias)
     else:
         assert not isinstance(input_tensor, TorchAOBaseTensor), (
             "Expecting input_tensor to be unquantized"
@@ -405,6 +420,37 @@ def _float8_linear_impl(
         )


+def _float8_matmul_impl(
+    input_tensor: torch.Tensor,
+    weight_tensor: torch.Tensor,
+) -> torch.Tensor:
+    assert isinstance(weight_tensor, Float8Tensor), (
+        f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}"
+    )
+
+    act_quant_kwargs = weight_tensor.act_quant_kwargs
+    # quantizing activation, if `act_quant_kwargs` is specified
+    if act_quant_kwargs is not None:
+        input_tensor = _choose_quant_func_and_quantize_tensor(
+            input_tensor, act_quant_kwargs
+        )
+
+    if isinstance(input_tensor, Float8Tensor):
+        kernel_choice = _get_matmul_kernel_choice(weight_tensor)
+        if kernel_choice == "fbgemm":
+            return _call_fbgemm_f8f8bf16_matmul(input_tensor, weight_tensor)
+        else:
+            assert kernel_choice == "torch"
+            return _call_float8_scaled_mm(input_tensor, weight_tensor)
+    else:
+        assert not isinstance(input_tensor, TorchAOBaseTensor), (
+            "Expecting input_tensor to be unquantized"
+        )
+        # when input is not `Float8Tensor`, we expect that it is not quantized
+        # so this is float8 weight only quantization
+        return torch.matmul(input_tensor, weight_tensor.dequantize())
+
+
 @implements_torch_function(torch.bmm)
 def _(func, types, args, kwargs):
     input_tensor, weight_tensor = (
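
In `_float8_matmul_impl`, when `act_quant_kwargs` is set the activation is quantized on the fly before dispatch; otherwise the weight is simply dequantized (weight-only quantization). As an illustration of what the rowwise float8 scaling checked by `_is_rowwise_scaled` means (one scale per row), here is a standalone sketch in plain PyTorch; `quantize_rowwise_fp8` is a made-up helper for this sketch, not the torchao quantization routine.

# Illustrative rowwise float8 quantization (not torchao code).
import torch

def quantize_rowwise_fp8(t: torch.Tensor):
    fp8_max = torch.finfo(torch.float8_e4m3fn).max         # 448.0
    amax = t.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = amax / fp8_max                                  # one scale per row
    qdata = (t / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return qdata, scale

x = torch.randn(8, 16)
xq, x_scale = quantize_rowwise_fp8(x)
x_hat = xq.to(torch.float32) * x_scale                      # dequantize
print((x - x_hat).abs().max())                              # small quantization error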
