diff --git a/requirements/tpu.txt b/requirements/tpu.txt
index a4aee21d2bd9..db58b37c2b15 100644
--- a/requirements/tpu.txt
+++ b/requirements/tpu.txt
@@ -18,9 +18,9 @@ setuptools==78.1.0
 --find-links https://storage.googleapis.com/libtpu-releases/index.html
 --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
 --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-torch==2.9.0.dev20250703
-torchvision==0.24.0.dev20250703
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250703-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250703-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250703-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch==2.9.0.dev20250711
+torchvision==0.24.0.dev20250711
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250711-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250711-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250711-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
diff --git a/tests/tpu/test_quantization_accuracy.py b/tests/tpu/test_quantization_accuracy.py
index a13cf7064d54..6cefbae4bdd1 100644
--- a/tests/tpu/test_quantization_accuracy.py
+++ b/tests/tpu/test_quantization_accuracy.py
@@ -14,7 +14,7 @@
 @dataclass
 class GSM8KAccuracyTestConfig:
     model_name: str
-    excepted_value: float
+    expected_value: float
 
     def get_model_args(self) -> str:
         return (f"pretrained={self.model_name},"
@@ -25,13 +25,13 @@ def get_model_args(self) -> str:
 ACCURACY_CONFIGS = [
     GSM8KAccuracyTestConfig(
         model_name="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8",
-        excepted_value=0.76),  # no bias
+        expected_value=0.76),  # no bias
     # NOTE(rob): We cannot re-initialize vLLM in the same process for TPU,
     # so only one of these tests can run in a single call to pytest. As
     # a follow up, move this into the LM-EVAL section of the CI.
     # GSM8KAccuracyTestConfig(
     #     model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8",
-    #     excepted_value=0.66),  # bias in QKV layers
+    #     expected_value=0.66),  # bias in QKV layers
 ]
 
@@ -45,7 +45,7 @@ def test_gsm8k_correctness(config: GSM8KAccuracyTestConfig):
         batch_size="auto",
     )
 
-    EXPECTED_VALUE = config.excepted_value
+    EXPECTED_VALUE = config.expected_value
     measured_value = results["results"][TASK][FILTER]
     assert (measured_value - RTOL < EXPECTED_VALUE
             and measured_value + RTOL > EXPECTED_VALUE
diff --git a/tests/v1/tpu/test_basic.py b/tests/v1/tpu/test_basic.py
index c0d2192ad813..c8cd099a98cf 100644
--- a/tests/v1/tpu/test_basic.py
+++ b/tests/v1/tpu/test_basic.py
@@ -145,3 +145,35 @@ def test_gemma3_27b_with_text_input_and_tp(
     for output, answer in zip(vllm_outputs, answers):
         generated_text = output[1]
         assert answer in generated_text
+
+
+@pytest.mark.skipif(not current_platform.is_tpu(),
+                    reason="This is a basic test for TPU only")
+def test_w8a8_quantization(
+    vllm_runner: type[VllmRunner],
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    model = "neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8"
+    max_tokens = 5
+    tensor_parallel_size = 1
+    max_num_seqs = 4
+
+    prompt = "The next numbers of the sequence " + ", ".join(
+        str(i) for i in range(1024)) + " are:"
+    example_prompts = [prompt]
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        with vllm_runner(
+                model,
+                max_num_batched_tokens=64,
+                max_model_len=4096,
+                gpu_memory_utilization=0.7,
+                max_num_seqs=max_num_seqs,
+                tensor_parallel_size=tensor_parallel_size) as vllm_model:
+            vllm_outputs = vllm_model.generate_greedy(example_prompts,
+                                                      max_tokens)
+        output = vllm_outputs[0][1]
+
+        assert "1024" in output or "0, 1" in output
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py
index 3de28af40aaa..0b931b2d8b81 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py
@@ -90,16 +90,15 @@ def apply_weights(self,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         w_q, w_s, _, _, _ = self._get_weight_params(layer)
 
-        import torch_xla.experimental.xla_quantized_matmul  # noqa: F401
-        out = torch.ops.xla.quantized_matmul(x,
-                                             w_q,
-                                             w_s,
-                                             zero_point=None,
-                                             block_size=-1,
-                                             int4_weight=False,
-                                             quantize_activation=True)
-        # `quantized_matmul` output is fp32, cast it down to bf16 for perf
-        out = out.to(x.dtype)
+        # Required to register custom ops.
+        import torch_xla.experimental.custom_kernel  # noqa: F401
+        out = torch.ops.xla.quantized_matmul_int8(
+            x,
+            w_q,
+            w_s,
+            quantize_activation=True,
+        )
+
         # Explicitly capture control flow to make dynamo happy.
         # https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method  # noqa: E501
         return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias])
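
For reference, here is a minimal standalone sketch of what the new `test_w8a8_quantization` smoke test exercises, assuming a TPU host with the nightly wheels pinned above. It uses the public `vllm.LLM` API instead of the `vllm_runner` test fixture; the engine arguments and prompt mirror the test, while the sampling setup and output handling are illustrative.

```python
import os

from vllm import LLM, SamplingParams

# Same switch the test sets via monkeypatch; must be set before engine init.
os.environ["VLLM_USE_V1"] = "1"

# Engine arguments copied from the vllm_runner(...) call in the test.
llm = LLM(
    model="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8",
    max_num_batched_tokens=64,
    max_model_len=4096,
    gpu_memory_utilization=0.7,
    max_num_seqs=4,
    tensor_parallel_size=1,
)

# Prompt from the test: count 0..1023 and ask for the continuation. It is far
# larger than max_num_batched_tokens=64, so prefill runs in multiple chunks.
prompt = ("The next numbers of the sequence " +
          ", ".join(str(i) for i in range(1024)) + " are:")

# Greedy decoding for 5 tokens, matching generate_greedy(..., max_tokens=5).
outputs = llm.generate([prompt], SamplingParams(temperature=0.0, max_tokens=5))
print(outputs[0].outputs[0].text)  # the test expects "1024" or "0, 1" here
```

On TPU this request exercises the `apply_weights` path in xla.py, i.e. the `torch.ops.xla.quantized_matmul_int8` call introduced in the last hunk above.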