Commit 5f6ec32

Fix Float8Tensor quantize op kernel preference dispatch
Summary:
Previously we didn't handle kernel_preference == "fbgemm" properly for the quantize op; this PR makes sure we dispatch to fbgemm kernels when kernel_preference is fbgemm.

This doesn't have much impact on BC: serialized checkpoints use AUTO, which is dispatched to the triton op for quantize. The only behavior change is fixing the kernel choice for the fbgemm kernel preference, which is meant to be a developer-facing API (we expect most users to just use AUTO without worrying about the details).

Test Plan:
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_kernel_preference_numerical_equivalence

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2883, branch: jerryzh168/stack/59
1 parent 3bf21d0 commit 5f6ec32
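For context, a minimal usage sketch of the knob this commit fixes. It assumes the float8 config (Float8DynamicActivationFloat8WeightConfig here) accepts a kernel_preference argument wired to the KernelPreference enum; that wiring is not shown in this diff, so treat the config name and argument as assumptions.

```python
# Hedged sketch: explicitly requesting fbgemm quantize/mm kernels for a linear layer.
# Requires fbgemm_gpu_genai and SM 9.0+ hardware; most users should keep the default AUTO.
import torch
import torch.nn as nn

from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, quantize_
from torchao.quantization.granularity import PerRow
from torchao.quantization.quantize_.common.kernel_preference import KernelPreference

model = nn.Sequential(nn.Linear(1024, 1024, dtype=torch.bfloat16, device="cuda"))

quantize_(
    model,
    Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
        # developer-facing knob fixed by this PR (assumed to be plumbed through the config)
        kernel_preference=KernelPreference.FBGEMM,
    ),
)
```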

File tree

torchao/quantization/quantize_/common/kernel_preference.py
torchao/quantization/quantize_/workflows/float8/float8_tensor.py

2 files changed: +62 −8 lines


torchao/quantization/quantize_/common/kernel_preference.py
Lines changed: 4 additions & 0 deletions

```diff
@@ -30,5 +30,9 @@ class KernelPreference(str, Enum):
     """
     FBGEMM = "fbgemm"
 
+    """Use triton quantize and quantized mm kernels (if available), requires fbgemm_gpu_genai library, if no triton kernel for the quantize op or mm kernel is available, we'll fallback to torch ops
+    """
+    TRITON = "triton"
+
 
 torch.serialization.add_safe_globals([KernelPreference])
```
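The new member behaves like the existing string-valued ones; a small sketch, assuming the AUTO/TORCH/FBGEMM members already defined earlier in this file:

```python
from torchao.quantization.quantize_.common.kernel_preference import KernelPreference

# str-valued enum, so it round-trips cleanly through configs and string comparisons
assert KernelPreference.TRITON.value == "triton"
assert KernelPreference("triton") is KernelPreference.TRITON

# The module registers the enum with torch.serialization.add_safe_globals (last line
# of the file), so checkpoints/configs containing it remain loadable with
# torch.load(..., weights_only=True).
```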

torchao/quantization/quantize_/workflows/float8/float8_tensor.py
Lines changed: 58 additions & 8 deletions

```diff
@@ -22,7 +22,7 @@
     preprocess_data,
     preprocess_scale,
 )
-from torchao.quantization.granularity import PerRow
+from torchao.quantization.granularity import PerRow, PerTensor
 from torchao.quantization.observer import get_block_size
 from torchao.quantization.quant_primitives import (
     _choose_scale_float8,
@@ -177,18 +177,33 @@ def to_float8(
         block_size = get_block_size(hp_tensor.shape, granularity)
         block_size = list(block_size)
 
-        # for per row quantization and kernel_preference default setting, we'll use triton kernel for best performance
+        kernel_choice = None
         if (
             kernel_preference == KernelPreference.AUTO
             and _is_fbgemm_genai_gpu_available()
-            and (
-                tuple(block_size)
-                == (1,) * (hp_tensor.ndim - 1) + (hp_tensor.shape[-1],)
+            and is_sm_at_least_90()
+            and isinstance(granularity, PerRow)
+            and float8_dtype == torch.float8_e4m3fn
+            and hp_value_lb is None
+        ) or kernel_preference == KernelPreference.TRITON:
+            # for per row quantization and kernel_preference auto setting
+            # we'll use triton quantize kernel for best performance
+            kernel_choice = "triton"
+        elif kernel_preference == KernelPreference.FBGEMM:
+            # we'll use fbgemm quantize kernel if it's explicitly chosen by user
+            assert _is_fbgemm_genai_gpu_available() and is_sm_at_least_90(), (
+                "Specified fbgemm but fbgemm_gpu_genai is not installed or hardware is not >= SM 9.0 (> H100)"
             )
-        ):
-            assert float8_dtype == torch.float8_e4m3fn, (
-                f"Only torch.float8_e4m3fn is supported, got: {float8_dtype}"
+            assert hp_value_lb is None, (
+                "hp_value_lb should not be specified if FBGEMM is explicitly chosen"
             )
+            kernel_choice = "fbgemm"
+        else:
+            # fallback quantize kernel for everything else will be torch
+            kernel_choice = "torch"
+
+        if kernel_choice == "triton":
+            assert hp_value_lb is None, f"{hp_value_lb=} is not supported"
             if hp_value_ub is not None:
                 maybe_hp_value_ub_tensor = torch.tensor(
                     hp_value_ub, dtype=torch.float, device=hp_tensor.device
@@ -202,7 +217,39 @@ def to_float8(
             for i in range(hp_tensor.ndim):
                 scale_shape.append(hp_tensor.shape[i] // block_size[i])
             scale = scale.reshape(*scale_shape)
+        elif kernel_choice == "fbgemm":
+            assert hp_value_lb is None, f"{hp_value_lb=} is not supported"
+            if hp_value_ub is not None:
+                maybe_hp_value_ub_tensor = torch.tensor(
+                    hp_value_ub, dtype=torch.float, device=hp_tensor.device
+                )
+            else:
+                maybe_hp_value_ub_tensor = None
+            # not used
+            num_tokens = torch.empty([hp_tensor.size(0)], device=hp_tensor.device)
+            if isinstance(granularity, PerRow):
+                data, scale = torch.ops.fbgemm.quantize_fp8_per_row(
+                    hp_tensor, num_tokens, scale_ub=maybe_hp_value_ub_tensor
+                )
+            else:
+                assert isinstance(granularity, PerTensor), (
+                    f"Expected per tensor, got {granularity}"
+                )
+                # TODO: use fbgemm kernel when it works
+                # current error: torch.AcceleratorError: CUDA error: an illegal memory access was encountered
+                # data, scale = torch.ops.fbgemm.quantize_fp8_per_tensor(
+                #     hp_tensor, num_tokens, scale_ub=maybe_hp_value_ub_tensor
+                # )
+                scale = _choose_scale_float8(
+                    hp_tensor,
+                    float8_dtype=float8_dtype,
+                    block_size=block_size,
+                    hp_value_lb=hp_value_lb,
+                    hp_value_ub=hp_value_ub,
+                )
+                data = _quantize_affine_float8(hp_tensor, scale, float8_dtype)
         else:
+            assert kernel_choice == "torch", f"Expected torch, got {kernel_choice}"
             scale = _choose_scale_float8(
                 hp_tensor,
                 float8_dtype=float8_dtype,
```
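For reference, a standalone sketch of the fbgemm per-row quantize op that the new kernel_choice == "fbgemm" branch calls, using the same signature as the hunk above. The fbgemm_gpu import path is an assumption; running it requires the fbgemm_gpu_genai package and SM 9.0+ hardware.

```python
import torch
import fbgemm_gpu.experimental.gen_ai  # noqa: F401  assumed import that registers torch.ops.fbgemm quantize ops

x = torch.randn(16, 256, dtype=torch.bfloat16, device="cuda")
# mirrors the diff: num_tokens is allocated but not used by the kernel
num_tokens = torch.empty([x.size(0)], device=x.device)
data, scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, scale_ub=None)
print(data.dtype, data.shape, scale.shape)  # float8_e4m3fn data plus one scale per row
```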
```diff
@@ -256,6 +303,9 @@ def _(func, types, args, kwargs):
         kernel_choice = "fbgemm"
     elif weight_tensor.kernel_preference == KernelPreference.FBGEMM:
         kernel_choice = "fbgemm"
+    elif weight_tensor.kernel_preference == KernelPreference.TRITON:
+        # no triton gemm op is available, so we'll fallback to torch
+        kernel_choice = "torch"
     else:
         assert weight_tensor.kernel_preference == KernelPreference.TORCH, (
             f"{weight_tensor.kernel_preference=} not handled"
```
