
Commit 4516f6e

Fix Float8Tensor quantize op kernel preference dispatch
Summary:

Previously, if the user specified kernel_preference == "fbgemm", we would still use torch ops such as `_choose_scale_float8` and `_quantize_affine_float8` to quantize the high precision Tensor into a float8 Tensor. This PR makes sure we use fbgemm kernels when kernel_preference is "fbgemm": `torch.ops.triton.quantize_fp8_row` for per row quantization, and `torch.ops.fbgemm.quantize_fp8_per_tensor` for per tensor quantization (the per tensor op has some issues right now, so it will be enabled later once it is fixed).

This has no impact on BC: previously serialized models can still be loaded and run. The only change is the kernel choice for the fbgemm kernel preference, so users who requested the FBGEMM kernel preference now actually run the fbgemm quantize op instead of the torch op.

Test Plan:

python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_expected_gpu_kernel_fbgemm

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2883, branch: jerryzh168/stack/59
1 parent 2a53216 commit 4516f6e
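
For context, a minimal usage sketch of the path this commit fixes, mirroring the new test below: a user who requests KernelPreference.FBGEMM should now actually hit the fbgemm quantize kernel (`torch.ops.triton.quantize_fp8_row` for per row) instead of the torch ops. The import paths are assumed from the files touched in this commit and may differ from the actual public exports.

import torch

from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)
# assumed export location, based on the file path in this commit
from torchao.quantization.quantize_.common.kernel_preference import KernelPreference

# bf16 linear layer on an SM 9.0+ GPU with fbgemm_gpu_genai installed
m = torch.nn.Sequential(torch.nn.Linear(256, 512, device="cuda", dtype=torch.bfloat16))

# explicitly request the fbgemm quantize and mm kernels
config = Float8DynamicActivationFloat8WeightConfig(
    granularity=PerRow(),
    kernel_preference=KernelPreference.FBGEMM,
)
quantize_(m, config)

# after this fix, the per row quantize step dispatches to torch.ops.triton.quantize_fp8_row;
# the per tensor path (torch.ops.fbgemm.quantize_fp8_per_tensor) is still disabled
x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
out = m(x)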

File tree

3 files changed: +86, -17 lines

test/quantization/quantize_/workflows/float8/test_float8_tensor.py
torchao/quantization/quantize_/common/kernel_preference.py
torchao/quantization/quantize_/workflows/float8/float8_tensor.py

test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 41 additions & 1 deletion
@@ -10,6 +10,8 @@
 from typing import Tuple
 
 import torch
+from torch._inductor.utils import run_and_get_code
+from torch.testing import FileCheck
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import (
     run_tests,
@@ -85,6 +87,14 @@ def test_fp8_linear_variants(
         kernel_preference: KernelPreference,
         sizes: Tuple,
     ):
+        if (
+            isinstance(granularity, PerTensor)
+            and kernel_preference == KernelPreference.FBGEMM
+        ):
+            return unittest.skip(
+                "per tensor with fbgemm kernel preferece does not work yet"
+            )
+
         error_message = None
         if isinstance(granularity, PerRow):
             if mode == "dynamic" and dtype != torch.bfloat16:
@@ -237,7 +247,11 @@ def test_kernel_preference_numerical_equivalence(self, granularity, sizes):
         other_kernel_preferences = [
             KernelPreference.AUTO,
         ]
-        if _is_fbgemm_genai_gpu_available() and is_sm_at_least_90():
+        if (
+            _is_fbgemm_genai_gpu_available()
+            and is_sm_at_least_90()
+            and not isinstance(granularity, PerTensor)
+        ):
             other_kernel_preferences.append(KernelPreference.FBGEMM)
 
         quantized_outputs = {}
@@ -399,6 +413,32 @@ def test_moe_weight_reshape_ops(self):
         config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         self._test_moe_weight_reshape_ops(config)
 
+    # TODO: we have some other tests living in https://github.com/pytorch/ao/blob/4ecc89edd7b5cfc12e6f80854c85d04c472a0eb0/test/dtypes/test_affine_quantized_float.py#L743
+    # that should be moved here after v1 config is deprecated:
+    # https://github.com/pytorch/ao/issues/2649
+    @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
+    def test_expected_gpu_kernel_fbgemm(self):
+        """Making sure KernelPreference.FBGEMM calls correct quantize and gemm kernels"""
+        torch.compiler.reset()
+
+        M, K, N = 128, 256, 512
+        m = torch.nn.Sequential(
+            torch.nn.Linear(K, N, device="cuda", dtype=torch.bfloat16)
+        )
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=PerRow(),
+            kernel_preference=KernelPreference.FBGEMM,
+        )
+        quantize_(m, config)
+        m = torch.compile(m)
+        x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
+        out, code = run_and_get_code(m, x)
+
+        # check at least one occurrence of the quantize op and rowwise gemm op
+        FileCheck().check_count(
+            "torch.ops.triton.quantize_fp8_row.default", 1
+        ).check_count("torch.ops.fbgemm.f8f8bf16_rowwise.default", 1).run(code[0])
+
 
 common_utils.instantiate_parametrized_tests(TestFloat8Tensor)
 

torchao/quantization/quantize_/common/kernel_preference.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ class KernelPreference(str, Enum):
     """
     TORCH = "torch"
 
-    """Use fbgemm quantize and quantized mm kernels, requires fbgemm_gpu_genai library
+    """Use quantize and quantized mm kernels from fbgemm_gpu_genai library, requires fbgemm_gpu_genai library
     """
     FBGEMM = "fbgemm"
 
torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 44 additions & 15 deletions
@@ -22,7 +22,7 @@
     preprocess_data,
     preprocess_scale,
 )
-from torchao.quantization.granularity import PerRow
+from torchao.quantization.granularity import PerRow, PerTensor
 from torchao.quantization.observer import get_block_size
 from torchao.quantization.quant_primitives import (
     _choose_scale_float8,
@@ -177,32 +177,61 @@ def to_float8(
         block_size = get_block_size(hp_tensor.shape, granularity)
         block_size = list(block_size)
 
-        # for per row quantization and kernel_preference default setting, we'll use triton kernel for best performance
+        kernel_choice = None
         if (
             kernel_preference == KernelPreference.AUTO
             and _is_fbgemm_genai_gpu_available()
-            and (
-                tuple(block_size)
-                == (1,) * (hp_tensor.ndim - 1) + (hp_tensor.shape[-1],)
-            )
+            and is_sm_at_least_90()
+            and isinstance(granularity, PerRow)
+            and float8_dtype == torch.float8_e4m3fn
+            and hp_value_lb is None
         ):
-            assert float8_dtype == torch.float8_e4m3fn, (
-                f"Only torch.float8_e4m3fn is supported, got: {float8_dtype}"
+            # if kernel_preference is AUTO and per row quantization
+            # we'll use fbgemm quantize kernel for best performance
+            kernel_choice = "fbgemm"
+        elif kernel_preference == KernelPreference.FBGEMM:
+            # if user explicitly chose FBGEMM kernel preference, we'll also use fbgemm kernel
+            assert _is_fbgemm_genai_gpu_available() and is_sm_at_least_90(), (
+                "Specified fbgemm but fbgemm_gpu_genai is not installed or hardware is not >= SM 9.0 (>= H100)"
+            )
+            assert hp_value_lb is None, (
+                "hp_value_lb should not be specified if with KerenelPreference.FBGEMM"
             )
+            kernel_choice = "fbgemm"
+        else:
+            # fallback quantize kernel for everything else will be torch
+            kernel_choice = "torch"
+
+        if kernel_choice == "fbgemm":
+            assert hp_value_lb is None, f"{hp_value_lb=} is not supported"
             if hp_value_ub is not None:
                 maybe_hp_value_ub_tensor = torch.tensor(
                     hp_value_ub, dtype=torch.float, device=hp_tensor.device
                 )
             else:
                 maybe_hp_value_ub_tensor = None
-            data, scale = torch.ops.triton.quantize_fp8_row(
-                hp_tensor, scale_ub=maybe_hp_value_ub_tensor
-            )
-            scale_shape = []
-            for i in range(hp_tensor.ndim):
-                scale_shape.append(hp_tensor.shape[i] // block_size[i])
-            scale = scale.reshape(*scale_shape)
+            if isinstance(granularity, PerRow):
+                data, scale = torch.ops.triton.quantize_fp8_row(
+                    hp_tensor, scale_ub=maybe_hp_value_ub_tensor
+                )
+                scale_shape = []
+                for i in range(hp_tensor.ndim):
+                    scale_shape.append(hp_tensor.shape[i] // block_size[i])
+                scale = scale.reshape(*scale_shape)
+            else:
+                assert isinstance(granularity, PerTensor), (
+                    f"Expected per tensor, got {granularity}"
+                )
+                # current error: torch.AcceleratorError: CUDA error: an illegal memory access was encountered
+                # TODO: enable after this is working
+                # data, scale = torch.ops.fbgemm.quantize_fp8_per_tensor(
+                #     hp_tensor, num_tokens, scale_ub=maybe_hp_value_ub_tensor
+                # )
+                raise NotImplementedError(
+                    "Currently KernelPreference.FBGEMM does not work for per tensor float8 quant"
+                )
         else:
+            assert kernel_choice == "torch", f"Expected torch, got {kernel_choice}"
             scale = _choose_scale_float8(
                 hp_tensor,
                 float8_dtype=float8_dtype,