
Commit 95f42d0

support select.int for Float8Tensor
Summary:

This is useful for stitching together 2D weights into a 3D weight; specifically, this happens in vLLM for HF models, where expert weights are 2D.

Test Plan:

```bash
pytest test/quantization/quantize_/workflows/float8/test_float8_tensor.py -x -s -k index
```

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 8e2ca35 commit 95f42d0
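
For context, here is a minimal sketch of the pattern this enables; the shapes, the `num_experts` loop, and the variable names are illustrative assumptions, not vLLM's actual loader code. A single 3D `Float8Tensor` is built with `Float8Tensor.from_hp`, then indexed per expert, with each `w3d_fp8[expert_id]` dispatching to the `aten.select.int` override added in this commit:

```python
import torch

from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor

# Illustrative shapes only: one 3D tensor holding all expert weights.
num_experts, N, K = 8, 256, 512
w3d = torch.randn(num_experts, N, K, device="cuda", dtype=torch.bfloat16)
w3d_fp8 = Float8Tensor.from_hp(w3d)

for expert_id in range(num_experts):
    # `w3d_fp8[expert_id]` dispatches to aten.select.int and yields a
    # 2D Float8Tensor for a single expert's weight.
    w2d_fp8 = w3d_fp8[expert_id]
    assert w2d_fp8.shape == (N, K)
```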

File tree

2 files changed: +34 -0 lines changed

test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 18 additions & 0 deletions
Lines changed: 18 additions & 0 deletions

```diff
@@ -25,6 +25,7 @@
     quantize_,
 )
 from torchao.quantization.quantize_.common import KernelPreference
+from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import TorchAOIntegrationTestCase
 from torchao.utils import (
@@ -446,6 +447,23 @@ def test_expected_gpu_kernel_fbgemm(self):
             ".run("
         ).run(code[0])
 
+    @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
+    def test_index_select(self):
+        """
+        test that `x_0 = x[0]` works when `x` is a 3D `Float8Tensor`. This is
+        useful when stitching checkpoints of `num_experts` 2D parameters into
+        a single 3D parameter when converting between model definitions that
+        use 2D and 3D parameters for their expert weights.
+        """
+
+        E, K, N = 128, 256, 512
+        x = torch.randn(E, N, K, device="cuda", dtype=torch.bfloat16)
+        x_fp8 = Float8Tensor.from_hp(x)
+        x_fp8_1 = x_fp8[1]
+        torch.testing.assert_close(
+            x_fp8.dequantize()[1], x_fp8_1.dequantize(), atol=0, rtol=0
+        )
+
 
 common_utils.instantiate_parametrized_tests(TestFloat8Tensor)
 
```
torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,22 @@ def _(func, types, args, kwargs):
617617
return return_and_correct_aliasing(func, args, kwargs, new)
618618

619619

620+
@implements(aten.select.int)
621+
def _(func, types, args, kwargs):
622+
old_float8_tensor, dim, index = args
623+
assert dim == 0, f"Float8Tensor aten.select.int with {dim=} is not yet supported"
624+
new_float8_tensor = old_float8_tensor.__class__(
625+
old_float8_tensor.qdata[index],
626+
old_float8_tensor.scale[index],
627+
old_float8_tensor.block_size[1:],
628+
old_float8_tensor.mm_config,
629+
old_float8_tensor.act_quant_kwargs,
630+
old_float8_tensor.kernel_preference,
631+
old_float8_tensor.dtype,
632+
)
633+
return return_and_correct_aliasing(func, args, kwargs, new_float8_tensor)
634+
635+
620636
Float8Tensor.__module__ = "torchao.quantization"
621637

622638
# Allow a model with Float8Tensor weights to be loaded with `weights_only=True`
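
Conceptually, the override narrows each component by the selected leading index: `qdata` and `scale` are indexed, and the leading `block_size` entry is dropped, so dequantizing the selected slice matches slicing the dequantized whole exactly. A small sketch of that invariant follows; the shapes are arbitrary, and it mirrors the new test rather than introducing any new API:

```python
import torch

from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor

E, N, K = 4, 32, 64
x = torch.randn(E, N, K, device="cuda", dtype=torch.bfloat16)
x_fp8 = Float8Tensor.from_hp(x)

x0 = x_fp8[0]  # aten.select.int with dim=0; other dims assert for now

# The leading block_size entry is dropped for the 2D result.
assert list(x0.block_size) == list(x_fp8.block_size)[1:]

# Selecting then dequantizing matches dequantizing then selecting, bitwise.
torch.testing.assert_close(x_fp8.dequantize()[0], x0.dequantize(), atol=0, rtol=0)
```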
