Commit 0b2ab3e

Fix Float8Tensor quantize op kernel preference dispatch
Summary:
Previously we didn't handle kernel_preference == "fbgemm" properly for the quantize op; this PR makes sure we dispatch to fbgemm kernels when kernel_preference is fbgemm.

This has no impact on BC: serialized checkpoints use AUTO, which is dispatched to the fbgemm triton op for quantize. The only change is fixing the kernel choice for the fbgemm kernel preference, which is meant to be a developer-facing API (we expect most users to just use AUTO without worrying about the details).

Test Plan:
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_kernel_preference_numerical_equivalence

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2883, branch: jerryzh168/stack/59
1 parent 2a53216 commit 0b2ab3e
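
As context for the summary above, here is a minimal usage sketch of how a developer might opt into the fbgemm kernels whose dispatch this commit fixes. It assumes KernelPreference is importable from torchao.quantization.quantize_.common and that Float8DynamicActivationFloat8WeightConfig accepts a kernel_preference argument; neither is confirmed by this page, so treat the names as illustrative only.

# Hypothetical usage sketch; the import path for KernelPreference and the
# kernel_preference keyword on the config are assumptions, not verbatim API.
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)
from torchao.quantization.quantize_.common import KernelPreference  # assumed path

model = torch.nn.Sequential(torch.nn.Linear(128, 256)).to(torch.bfloat16).cuda()

# Most users keep the default AUTO preference; FBGEMM is the developer-facing
# override whose quantize dispatch this commit fixes.
quantize_(
    model,
    Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
        kernel_preference=KernelPreference.FBGEMM,  # assumed keyword
    ),
)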

3 files changed: +58 -17 lines changed


test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 13 additions & 1 deletion
@@ -85,6 +85,14 @@ def test_fp8_linear_variants(
         kernel_preference: KernelPreference,
         sizes: Tuple,
     ):
+        if (
+            isinstance(granularity, PerTensor)
+            and kernel_preference == KernelPreference.FBGEMM
+        ):
+            return unittest.skip(
+                "per tensor with fbgemm kernel preference does not work yet"
+            )
+
         error_message = None
         if isinstance(granularity, PerRow):
             if mode == "dynamic" and dtype != torch.bfloat16:
@@ -237,7 +245,11 @@ def test_kernel_preference_numerical_equivalence(self, granularity, sizes):
         other_kernel_preferences = [
             KernelPreference.AUTO,
         ]
-        if _is_fbgemm_genai_gpu_available() and is_sm_at_least_90():
+        if (
+            _is_fbgemm_genai_gpu_available()
+            and is_sm_at_least_90()
+            and not isinstance(granularity, PerTensor)
+        ):
             other_kernel_preferences.append(KernelPreference.FBGEMM)
 
         quantized_outputs = {}
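
The test above accumulates a quantized_outputs dict keyed by kernel preference and checks that all entries agree; a simplified, self-contained sketch of that comparison pattern is below. The make_quantized_model helper is hypothetical and stands in for the test's actual model construction and quantization.

# Simplified sketch of the numerical-equivalence check pattern;
# make_quantized_model is a hypothetical stand-in for the test's real setup.
import torch

def check_kernel_preference_equivalence(make_quantized_model, kernel_preferences, x):
    quantized_outputs = {}
    for kp in kernel_preferences:
        # quantize a fresh copy of the model under this kernel preference
        quantized_outputs[kp] = make_quantized_model(kp)(x)

    # compare every preference's output against the first one
    baseline_kp, baseline_out = next(iter(quantized_outputs.items()))
    for kp, out in quantized_outputs.items():
        # the real test may use exact equality or explicit tolerances
        torch.testing.assert_close(out, baseline_out)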

torchao/quantization/quantize_/common/kernel_preference.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ class KernelPreference(str, Enum):
     """
     TORCH = "torch"
 
-    """Use fbgemm quantize and quantized mm kernels, requires fbgemm_gpu_genai library
+    """Use quantize and quantized mm kernels from fbgemm_gpu_genai library, requires fbgemm_gpu_genai library
     """
     FBGEMM = "fbgemm"
 
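
Because KernelPreference subclasses both str and Enum (see the hunk header above), its members compare equal to their string values and round-trip from serialized strings, which is what keeps AUTO-based checkpoints working. A self-contained illustration using a simplified mirror of the enum (not the torchao class itself):

# Simplified mirror of the str + Enum pattern; AUTO's value is assumed to be "auto".
from enum import Enum

class KernelPreference(str, Enum):
    AUTO = "auto"
    TORCH = "torch"
    FBGEMM = "fbgemm"

assert KernelPreference.FBGEMM == "fbgemm"                    # plain-string comparison works
assert KernelPreference("fbgemm") is KernelPreference.FBGEMM  # round-trips from a serialized string
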
torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 44 additions & 15 deletions
@@ -22,7 +22,7 @@
     preprocess_data,
     preprocess_scale,
 )
-from torchao.quantization.granularity import PerRow
+from torchao.quantization.granularity import PerRow, PerTensor
 from torchao.quantization.observer import get_block_size
 from torchao.quantization.quant_primitives import (
     _choose_scale_float8,
@@ -177,32 +177,61 @@ def to_float8(
         block_size = get_block_size(hp_tensor.shape, granularity)
         block_size = list(block_size)
 
-        # for per row quantization and kernel_preference default setting, we'll use triton kernel for best performance
+        kernel_choice = None
         if (
             kernel_preference == KernelPreference.AUTO
             and _is_fbgemm_genai_gpu_available()
-            and (
-                tuple(block_size)
-                == (1,) * (hp_tensor.ndim - 1) + (hp_tensor.shape[-1],)
-            )
+            and is_sm_at_least_90()
+            and isinstance(granularity, PerRow)
+            and float8_dtype == torch.float8_e4m3fn
+            and hp_value_lb is None
         ):
-            assert float8_dtype == torch.float8_e4m3fn, (
-                f"Only torch.float8_e4m3fn is supported, got: {float8_dtype}"
+            # if kernel_preference is AUTO and per row quantization,
+            # we'll use the fbgemm quantize kernel for best performance
+            kernel_choice = "fbgemm"
+        elif kernel_preference == KernelPreference.FBGEMM:
+            # if the user explicitly chose the FBGEMM kernel preference, we'll also use the fbgemm kernel
+            assert _is_fbgemm_genai_gpu_available() and is_sm_at_least_90(), (
+                "Specified fbgemm but fbgemm_gpu_genai is not installed or hardware is not >= SM 9.0 (H100 or newer)"
+            )
+            assert hp_value_lb is None, (
+                "hp_value_lb should not be specified when using KernelPreference.FBGEMM"
             )
+            kernel_choice = "fbgemm"
+        else:
+            # fallback quantize kernel for everything else will be torch
+            kernel_choice = "torch"
+
+        if kernel_choice == "fbgemm":
+            assert hp_value_lb is None, f"{hp_value_lb=} is not supported"
             if hp_value_ub is not None:
                 maybe_hp_value_ub_tensor = torch.tensor(
                     hp_value_ub, dtype=torch.float, device=hp_tensor.device
                 )
             else:
                 maybe_hp_value_ub_tensor = None
-            data, scale = torch.ops.triton.quantize_fp8_row(
-                hp_tensor, scale_ub=maybe_hp_value_ub_tensor
-            )
-            scale_shape = []
-            for i in range(hp_tensor.ndim):
-                scale_shape.append(hp_tensor.shape[i] // block_size[i])
-            scale = scale.reshape(*scale_shape)
+            if isinstance(granularity, PerRow):
+                data, scale = torch.ops.triton.quantize_fp8_row(
+                    hp_tensor, scale_ub=maybe_hp_value_ub_tensor
+                )
+                scale_shape = []
+                for i in range(hp_tensor.ndim):
+                    scale_shape.append(hp_tensor.shape[i] // block_size[i])
+                scale = scale.reshape(*scale_shape)
+            else:
+                assert isinstance(granularity, PerTensor), (
+                    f"Expected per tensor, got {granularity}"
+                )
+                # current error: torch.AcceleratorError: CUDA error: an illegal memory access was encountered
+                # TODO: enable after this is working
+                # data, scale = torch.ops.fbgemm.quantize_fp8_per_tensor(
+                #     hp_tensor, num_tokens, scale_ub=maybe_hp_value_ub_tensor
+                # )
+                raise NotImplementedError(
+                    "Currently KernelPreference.FBGEMM does not work for per tensor float8 quant"
+                )
         else:
+            assert kernel_choice == "torch", f"Expected torch, got {kernel_choice}"
             scale = _choose_scale_float8(
                 hp_tensor,
                 float8_dtype=float8_dtype,

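Boiled down, the rewritten to_float8 above first resolves a kernel_choice string and only then runs the chosen quantize path. The standalone sketch below mirrors that two-step dispatch shape with the availability checks stubbed out as booleans; the real logic lives in Float8Tensor.to_float8 and additionally considers dtypes and value bounds.

# Standalone sketch of the two-step dispatch: resolve kernel_choice, then branch.
# Availability checks are stubbed; this is not the torchao implementation.
def resolve_kernel_choice(kernel_preference, is_per_row, fbgemm_available, sm90_or_newer):
    if kernel_preference == "auto" and fbgemm_available and sm90_or_newer and is_per_row:
        # AUTO silently upgrades to the fbgemm path when it is usable
        return "fbgemm"
    elif kernel_preference == "fbgemm":
        # an explicit FBGEMM preference must fail loudly if it cannot be honored
        assert fbgemm_available and sm90_or_newer, (
            "fbgemm kernel preference requires fbgemm_gpu_genai and SM >= 9.0"
        )
        return "fbgemm"
    else:
        # everything else falls back to the torch quantize kernels
        return "torch"

# the case this commit fixes: an explicit fbgemm preference now dispatches to fbgemm
assert resolve_kernel_choice("fbgemm", is_per_row=True, fbgemm_available=True, sm90_or_newer=True) == "fbgemm"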