Commit e549679

enroll mxtensor in vllm integration tests
Summary:

1. Enroll MXTensor in the existing vLLM slice-and-copy test, and make it pass by moving to TorchAOBaseTensor's copy_.
2. Add a test for the vLLM narrow pattern, and make it pass by fixing an incorrect slice implementation.

This may be useful for other tensors as well; they can opt in in separate PRs.

Test Plan:

```
pytest test/prototype/mx_formats/ -s -x
```

Also, this PR enables running mxfp4 weight-only Qwen MoE models in vLLM.

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 609a88e
ghstack-comment-id: 3339065416
Pull Request resolved: #3081
1 parent 3947a7f commit e549679
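
For context on what the new tests exercise: when vLLM loads checkpoint shards, it narrows an already-quantized parameter to the shard's slot and then `copy_`s the (equally quantized) shard into that view, so both `slice`/`narrow` and `copy_` must work on the tensor subclass. Below is a minimal sketch of that pattern under the same `MXFPInferenceConfig` used in the new tests; the shapes, the 512-row shard, and the import paths are assumptions for illustration, not code from this commit or from vLLM.

```
# Sketch of the vLLM-style weight loading that the new slice-and-copy /
# narrow tests exercise. Shapes, the shard split, and the import paths are
# assumptions for this example, not code from this commit.
import torch

from torchao.prototype.mx_formats import MXFPInferenceConfig, MXGemmKernelChoice
from torchao.quantization import quantize_

config = MXFPInferenceConfig(
    activation_dtype=torch.float8_e4m3fn,
    weight_dtype=torch.float8_e4m3fn,
    gemm_kernel_choice=MXGemmKernelChoice.EMULATED,
)

# "param" stands in for the model's pre-allocated quantized weight, "loaded"
# for a checkpoint shard that was quantized the same way.
model_linear = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.bfloat16)
ckpt_linear = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.bfloat16)
quantize_(model_linear, config)
quantize_(ckpt_linear, config)
param, loaded = model_linear.weight, ckpt_linear.weight

# Narrow the parameter to the shard's slot along the output dim, then copy
# the loaded shard into that view (slice + copy_ on the tensor subclass).
param.data.narrow(0, 0, 512).copy_(loaded.narrow(0, 0, 512))

# vLLM also narrows weights along the input dim when splitting fused layers,
# which is what the new narrow test covers.
narrowed = param.narrow(1, 0, 1024)
```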

3 files changed, +49 -19 lines changed

test/prototype/mx_formats/test_inference_workflow.py

Lines changed: 29 additions & 1 deletion
@@ -21,7 +21,7 @@
 )
 from torchao.quantization import quantize_
 from torchao.quantization.utils import compute_error
-from torchao.testing.utils import skip_if_rocm
+from torchao.testing.utils import TorchAOIntegrationTestCase, skip_if_rocm
 from torchao.utils import (
     is_sm_at_least_89,
     is_sm_at_least_100,
@@ -190,3 +190,31 @@ def test_inference_workflow_nvfp4(
     assert sqnr >= SQNR_THRESHOLD, (
         f"Got a sqnr of {sqnr} for NVFP4 recipe with bias={bias}, mm_config={mm_config}"
     )
+
+
+class VLLMIntegrationTestCase(TorchAOIntegrationTestCase):
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(
+        not torch_version_at_least("2.8.0"),
+        reason="torch.compile requires PyTorch 2.8+",
+    )
+    def test_slice_and_copy_similar_to_vllm(self):
+        config = MXFPInferenceConfig(
+            activation_dtype=torch.float8_e4m3fn,
+            weight_dtype=torch.float8_e4m3fn,
+            gemm_kernel_choice=MXGemmKernelChoice.EMULATED,
+        )
+        self._test_slice_and_copy_similar_to_vllm(config)
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(
+        not torch_version_at_least("2.8.0"),
+        reason="torch.compile requires PyTorch 2.8+",
+    )
+    def test_narrow_similar_to_vllm(self):
+        config = MXFPInferenceConfig(
+            activation_dtype=torch.float8_e4m3fn,
+            weight_dtype=torch.float8_e4m3fn,
+            gemm_kernel_choice=MXGemmKernelChoice.EMULATED,
+        )
+        self._test_narrow_similar_to_vllm(config)
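
To run only the two tests added above rather than the whole directory from the Test Plan, a `-k` filter on their shared name suffix should work (assuming default pytest discovery):

```
pytest test/prototype/mx_formats/test_inference_workflow.py -k similar_to_vllm -s -x
```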

torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 1 addition & 17 deletions
@@ -836,9 +836,7 @@ def mx_slice(func, types, args, kwargs):
         end_block = -1 if end is None else end // x._block_size
 
         # Slice the scale tensor accordingly
-        sliced_scale = aten.slice.Tensor(
-            scale_shaped, 1, start_block, end_block, step
-        ).unsqueeze(-1)
+        sliced_scale = aten.slice.Tensor(scale_shaped, 1, start_block, end_block, step)
     else:
         raise ValueError(
             f"MXTensor only supports slicing along dimensions 0 and 1, got dim={dim}"
@@ -861,20 +859,6 @@ def mx_slice(func, types, args, kwargs):
     )
 
 
-@implements([aten.copy_.default])
-def mx_copy_(func, types, args, kwargs):
-    self = args[0]
-    src = args[1]
-    if MXTensor._same_metadata(self, src):
-        self_tensors = self.__tensor_flatten__()[0]
-        for tensor_name in self_tensors:
-            getattr(self, tensor_name).copy_(getattr(src, tensor_name))
-        return
-    raise ValueError(
-        f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}"
-    )
-
-
 @implements([aten.clone.default])
 def mx_clone(func, types, args, kwargs):
     self = args[0]
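
The one-line change in the first hunk above is the slice fix from the commit message: the old path tacked an extra dimension onto the sliced scale with `.unsqueeze(-1)`, which is consistent with the rank mismatch that the new narrow test catches. Below is a toy, plain-torch illustration of the invariant being restored; the 1024x32 scale shape is an assumption based on a 1024x1024 weight and the standard MX block size of 32, not values taken from this diff.

```
# Toy illustration (assumed shapes): slicing the per-block scales along dim 1
# should not change their rank. For a 1024x1024 weight with block_size=32 the
# scales can be viewed as (rows, num_blocks) = (1024, 32).
import torch

scale_shaped = torch.empty(1024, 32)

fixed = scale_shaped[:, 0:32]              # new code path: still rank 2
old = scale_shaped[:, 0:32].unsqueeze(-1)  # old code path: rank 3

assert fixed.dim() == scale_shaped.dim()    # the invariant _test_narrow_similar_to_vllm asserts
assert old.dim() == scale_shaped.dim() + 1  # the extra dim would trip the new test
```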

torchao/testing/utils.py

Lines changed: 19 additions & 1 deletion
@@ -16,6 +16,7 @@
 )
 
 import torchao
+from torchao.core.config import AOBaseConfig
 from torchao.dtypes import AffineQuantizedTensor, to_affine_quantized_intx
 from torchao.quantization import Int8WeightOnlyConfig, quantize_
 from torchao.quantization.quant_primitives import MappingType
@@ -426,7 +427,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class TorchAOIntegrationTestCase(common_utils.TestCase):
-    def _test_slice_and_copy_similar_to_vllm(self, config):
+    def _test_slice_and_copy_similar_to_vllm(self, config: AOBaseConfig):
         # making sure https://github.com/vllm-project/vllm/blob/90bd2ab6e3eb7e83d3f40d99fc23e6e43834743a/vllm/model_executor/layers/linear.py#L483-L495 works properly
         # the test is similar to the linked code, but with some hardcoded arguments
         # and does not use tensor parallelism
@@ -607,6 +608,23 @@ def process_key(key: str) -> torch.Tensor:
         # make sure it runs
         moe_combined(input)
 
+    def _test_narrow_similar_to_vllm(self, config: AOBaseConfig):
+        # this happens various times in vllm when slicing weights around
+
+        dtype = torch.bfloat16
+        l = torch.nn.Linear(1024, 1024, device="cuda", dtype=dtype)
+        quantize_(l, config)
+
+        orig = l.weight
+        new = orig.narrow(1, 0, 1024)
+
+        for data_attr_name in new.tensor_data_names:
+            orig_attr = getattr(orig, data_attr_name)
+            new_attr = getattr(new, data_attr_name)
+            assert len(orig_attr.shape) == len(new_attr.shape), (
+                f"shape mismatch: {orig_attr.shape} vs {new_attr.shape}"
+            )
+
 
 common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)
 common_utils.instantiate_parametrized_tests(TorchAOCompileTestCase)
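
As the commit message notes, other tensor subclasses can opt in to these checks in separate PRs by calling the shared helpers from their own test cases. A hypothetical sketch of such an opt-in follows; the config and test class here are examples, not something this commit adds, and whether a given config passes is exactly what an opt-in PR would verify.

```
# Hypothetical opt-in from another subclass's test file (not part of this commit).
import pytest
import torch

from torchao.quantization import Int8WeightOnlyConfig
from torchao.testing.utils import TorchAOIntegrationTestCase


class Int8VLLMIntegrationTestCase(TorchAOIntegrationTestCase):
    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    def test_slice_and_copy_similar_to_vllm(self):
        # reuse the shared helper exercised by the MX tests in this commit
        self._test_slice_and_copy_similar_to_vllm(Int8WeightOnlyConfig())

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    def test_narrow_similar_to_vllm(self):
        self._test_narrow_similar_to_vllm(Int8WeightOnlyConfig())
```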
