Commit e7b310b

Float8Tensor per row quantization pass bias to fbgemm kernel (#2884)
Summary: Previously, bias was not passed to the fbgemm kernel for float8 per-row quantization; this PR adds it. The result is a faster float8 per-row quantized path, without changing numerics or anything else.

Test Plan:
```
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_kernel_preference_numerical_equivalence
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_expected_gpu_kernel_fbgemm
```

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2884, branch: jerryzh168/stack/60
1 parent 2dacd7f commit e7b310b
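In effect, the change moves the bias add out of a separate elementwise kernel and into the epilogue of the fbgemm rowwise gemm. The sketch below only illustrates that before/after call pattern and is not code from this PR: it assumes fbgemm_gpu_genai is installed on an SM90+ GPU, and the shapes, dtypes, and pre-quantized inputs are made up for the example.

```python
import torch

# Hypothetical pre-quantized float8 inputs with per-row scales (illustration only).
M, K, N = 128, 256, 512
xq = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)   # activation
wq = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)   # weight
x_scale = torch.rand(M, device="cuda", dtype=torch.float32)
w_scale = torch.rand(N, device="cuda", dtype=torch.float32)
bias = torch.randn(N, device="cuda", dtype=torch.bfloat16)

# Before this PR: the rowwise gemm ran without bias, and bias was added afterwards
# in a separate elementwise kernel.
res_before = torch.ops.fbgemm.f8f8bf16_rowwise(xq, wq, x_scale, w_scale) + bias

# After this PR: bias is passed into the gemm kernel, so no extra add kernel is launched.
res_after = torch.ops.fbgemm.f8f8bf16_rowwise(xq, wq, x_scale, w_scale, bias=bias)

# The two should agree up to float8/bfloat16 rounding.
print((res_after.float() - res_before.float()).abs().max())
```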

File tree: 2 files changed, +18 -6 lines changed


test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 11 additions & 4 deletions
```diff
@@ -418,7 +418,9 @@ def test_moe_weight_reshape_ops(self):
     # https://github.com/pytorch/ao/issues/2649
     @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
     def test_expected_gpu_kernel_fbgemm(self):
-        """Making sure KernelPreference.FBGEMM calls correct quantize and gemm kernels"""
+        """Making sure KernelPreference.FBGEMM calls correct quantize and gemm kernels
+        and the bias add happens in the gemm kernel for per row quantization
+        """
         torch.compiler.reset()
 
         M, K, N = 128, 256, 512
@@ -434,10 +436,15 @@ def test_expected_gpu_kernel_fbgemm(self):
         x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
         out, code = run_and_get_code(m, x)
 
-        # check at least one occurrence of the quantize op and rowwise gemm op
+        # 1. check at least one occurrence of the quantize op and rowwise gemm op
+        # 2. check that there are no additional kernels like `triton_poi_fused_add_0`
+        # are run, since the bias add should happen in the `f8f8bf16_rowwise.default`
+        # op instead of separately
         FileCheck().check_count(
-            "torch.ops.triton.quantize_fp8_row.default", 1
-        ).check_count("torch.ops.fbgemm.f8f8bf16_rowwise.default", 1).run(code[0])
+            "torch.ops.triton.quantize_fp8_row.default(", 1
+        ).check_count("torch.ops.fbgemm.f8f8bf16_rowwise.default(", 1).check_not(
+            ".run("
+        ).run(code[0])
 
 
 common_utils.instantiate_parametrized_tests(TestFloat8Tensor)
```
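For readers unfamiliar with the checking pattern above: run_and_get_code compiles and runs the module while capturing the Inductor-generated source, and FileCheck then asserts which op calls do and do not appear in it. Below is a standalone sketch of that pattern on a toy module; the checked substrings (`extern_kernels.mm`, `triton_poi_fused_add`) are assumptions about typical Inductor output and vary by device, backend, and PyTorch version.

```python
import torch
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.mm(x, x)

m = torch.compile(Toy())
x = torch.randn(64, 64)

# Run the compiled module and capture the Inductor-generated wrapper code.
_, code = run_and_get_code(m, x)

# Assert the gemm call appears exactly once and no fused add kernel was emitted.
# The exact strings are backend-dependent; this only mirrors the style of the test above.
FileCheck().check_count("extern_kernels.mm(", 1).check_not(
    "triton_poi_fused_add"
).run(code[0])
```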

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 7 additions & 2 deletions
```diff
@@ -285,6 +285,8 @@ def _(func, types, args, kwargs):
             "Expected fbgemm_gpu_genai package to be installed"
         )
         assert is_sm_at_least_90(), "Expected SM90+ for fbgemm_gpu_genai"
+        mm_config = weight_tensor.mm_config
+        assert mm_config is not None
 
         out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)
         xq = input_tensor.qdata.reshape(-1, input_tensor.qdata.shape[-1])
@@ -300,6 +302,8 @@ def _(func, types, args, kwargs):
                 wq,
                 x_scale,
                 w_scale,
+                bias=bias,
+                use_fast_accum=mm_config.use_fast_accum,
             ).reshape(out_shape)
         else:
             assert _is_tensorwise_scaled(weight_tensor)
@@ -308,9 +312,10 @@ def _(func, types, args, kwargs):
                 xq,
                 wq,
                 x_scale * w_scale,
+                use_fast_accum=mm_config.use_fast_accum,
             ).reshape(out_shape)
-        if bias is not None:
-            res = res + bias
+            if bias is not None:
+                res = res + bias
         return res
     else:
         assert kernel_choice == "torch"
```
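For context, here is a rough end-to-end sketch of how this code path is reached from the user-facing API. It assumes a recent torchao build with fbgemm_gpu_genai installed on an SM90+ GPU; the KernelPreference import path and the exact config arguments are assumptions and may differ between torchao versions.

```python
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)
# Assumed import path; KernelPreference may live elsewhere depending on the torchao version.
from torchao.quantization.quantize_.common import KernelPreference

m = torch.nn.Linear(256, 512, bias=True, device="cuda", dtype=torch.bfloat16)
quantize_(
    m,
    Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
        kernel_preference=KernelPreference.FBGEMM,  # route matmuls to the fbgemm kernels above
    ),
)

x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
# With this change, the layer's bias is passed into f8f8bf16_rowwise rather than
# being added in a separate elementwise kernel after the gemm.
out = torch.compile(m)(x)
```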
