10 changes: 5 additions & 5 deletions requirements/tpu.txt
@@ -18,9 +18,9 @@ setuptools==78.1.0
 --find-links https://storage.googleapis.com/libtpu-releases/index.html
 --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
 --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-torch==2.9.0.dev20250703
-torchvision==0.24.0.dev20250703
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250703-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250703-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250703-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch==2.9.0.dev20250711
+torchvision==0.24.0.dev20250711
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250711-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250711-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250711-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"

8 changes: 4 additions & 4 deletions tests/tpu/test_quantization_accuracy.py
@@ -14,7 +14,7 @@
 @dataclass
 class GSM8KAccuracyTestConfig:
     model_name: str
-    excepted_value: float
+    expected_value: float

     def get_model_args(self) -> str:
         return (f"pretrained={self.model_name},"
@@ -25,13 +25,13 @@ def get_model_args(self) -> str:
 ACCURACY_CONFIGS = [
     GSM8KAccuracyTestConfig(
         model_name="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8",
-        excepted_value=0.76),  # no bias
+        expected_value=0.76),  # no bias
     # NOTE(rob): We cannot re-initialize vLLM in the same process for TPU,
     # so only one of these tests can run in a single call to pytest. As
     # a follow up, move this into the LM-EVAL section of the CI.
     # GSM8KAccuracyTestConfig(
     #     model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8",
-    #     excepted_value=0.66),  # bias in QKV layers
+    #     expected_value=0.66),  # bias in QKV layers
 ]


@@ -45,7 +45,7 @@ def test_gsm8k_correctness(config: GSM8KAccuracyTestConfig):
         batch_size="auto",
     )

-    EXPECTED_VALUE = config.excepted_value
+    EXPECTED_VALUE = config.expected_value
     measured_value = results["results"][TASK][FILTER]
     assert (measured_value - RTOL < EXPECTED_VALUE
             and measured_value + RTOL > EXPECTED_VALUE
32 changes: 32 additions & 0 deletions tests/v1/tpu/test_basic.py
@@ -145,3 +145,35 @@ def test_gemma3_27b_with_text_input_and_tp(
     for output, answer in zip(vllm_outputs, answers):
         generated_text = output[1]
         assert answer in generated_text
+
+
+@pytest.mark.skipif(not current_platform.is_tpu(),
+                    reason="This is a basic test for TPU only")
+def test_w8a8_quantization(
+    vllm_runner: type[VllmRunner],
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    model = "neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8"
+    max_tokens = 5
+    tensor_parallel_size = 1
+    max_num_seqs = 4
+
+    prompt = "The next numbers of the sequence " + ", ".join(
+        str(i) for i in range(1024)) + " are:"
+    example_prompts = [prompt]
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        with vllm_runner(
+                model,
+                max_num_batched_tokens=64,
+                max_model_len=4096,
+                gpu_memory_utilization=0.7,
+                max_num_seqs=max_num_seqs,
+                tensor_parallel_size=tensor_parallel_size) as vllm_model:
+            vllm_outputs = vllm_model.generate_greedy(example_prompts,
+                                                      max_tokens)
+        output = vllm_outputs[0][1]
+
+        assert "1024" in output or "0, 1" in output
@@ -90,16 +90,15 @@ def apply_weights(self,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         w_q, w_s, _, _, _ = self._get_weight_params(layer)

-        import torch_xla.experimental.xla_quantized_matmul  # noqa: F401
-        out = torch.ops.xla.quantized_matmul(x,
-                                             w_q,
-                                             w_s,
-                                             zero_point=None,
-                                             block_size=-1,
-                                             int4_weight=False,
-                                             quantize_activation=True)
-        # `quantized_matmul` output is fp32, cast it down to bf16 for perf
-        out = out.to(x.dtype)
+        # Required to register custom ops.
+        import torch_xla.experimental.custom_kernel  # noqa: F401
+        out = torch.ops.xla.quantized_matmul_int8(
+            x,
+            w_q,
+            w_s,
+            quantize_activation=True,
+        )
Comment on lines +95 to +100 (Contributor, severity: high):

The previous implementation using torch.ops.xla.quantized_matmul included a cast out = out.to(x.dtype) with the comment "quantized_matmul output is fp32, cast it down to bf16 for perf".

What is the output data type of the new torch.ops.xla.quantized_matmul_int8? If it also outputs in a higher precision (e.g., fp32) and x.dtype is a lower precision format like bfloat16, is a similar cast to x.dtype still necessary for performance or type consistency with subsequent layers? If the new op already handles this or outputs directly in x.dtype, then this change is fine.
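
If quantized_matmul_int8 does turn out to return fp32, one low-risk follow-up would be a guarded cast back to the activation dtype, mirroring the cast the previous code applied. A minimal sketch, assuming a hypothetical helper that is not part of this PR; the actual return dtype of the op still needs to be confirmed against torch_xla:

import torch

def cast_to_activation_dtype(out: torch.Tensor,
                             x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper (illustration only): if the quantized matmul
    # returns a higher-precision result (e.g. fp32) while activations are
    # bf16, cast it back so downstream layers stay in the activation dtype.
    return out if out.dtype == x.dtype else out.to(x.dtype)

# Usage sketch inside apply_weights, after the quantized_matmul_int8 call:
# out = cast_to_activation_dtype(out, x)

This is a no-op when the op already returns x.dtype, so it would be safe in either case.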


         # Explicitly capture control flow to make dynamo happy.
         # https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method  # noqa: E501
         return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias])