Commit b61dc5f

[TPU] update torch_xla pin (#19231)

Signed-off-by: Chengji Yao <chengjiyao@google.com>
Parent: f8a1a2d

3 files changed: 8 additions, 7 deletions

requirements/tpu.txt (5 additions, 5 deletions)

@@ -18,9 +18,9 @@ setuptools==78.1.0
 --find-links https://storage.googleapis.com/libtpu-releases/index.html
 --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
 --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-torch==2.8.0.dev20250529
-torchvision==0.22.0.dev20250529
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250529-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250529-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250529-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch==2.8.0.dev20250605
+torchvision==0.23.0.dev20250605
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250605-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250605-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250605-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
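
After installing with pip install -r requirements/tpu.txt, a quick sanity check can confirm the new pins took effect. This is a minimal sketch; it assumes only that the packages pinned above are installed:

    # Verify the bumped nightly pins (dev20250529 -> dev20250605).
    import torch
    import torchvision
    import torch_xla

    assert torch.__version__.startswith("2.8.0.dev20250605")
    assert torchvision.__version__.startswith("0.23.0.dev20250605")
    assert torch_xla.__version__.startswith("2.8.0")
    print(torch.__version__, torchvision.__version__, torch_xla.__version__)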

tests/tpu/test_moe_pallas.py (1 addition, 1 deletion)

@@ -27,7 +27,7 @@
 # The Pallas GMM kernel requires num_tokens * topk to be a multiple of 16
 @pytest.mark.parametrize("m", [8, 16, 64, 2048])
 @pytest.mark.parametrize("n", [128, 1024, 2048])
-@pytest.mark.parametrize("k", [128, 512, 1024])
+@pytest.mark.parametrize("k", [128, 511, 1024])
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("ep_size", EP_SIZE)

vllm/v1/worker/tpu_worker.py (2 additions, 1 deletion)

@@ -100,7 +100,8 @@ def init_device(self):
         # `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary solution to
         # fix this. It will be removed after the bug in XLA compiler is fixed.
         os.environ["LIBTPU_INIT_ARGS"] = (
-            "--xla_tpu_force_1d_allreduce_at_chunk_count=1")
+            os.environ.get("LIBTPU_INIT_ARGS", "") +
+            " --xla_tpu_force_1d_allreduce_at_chunk_count=1")
         torch.set_grad_enabled(False)
         torch.set_default_dtype(self.model_config.dtype)
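
The behavioral fix here: the worker previously clobbered any LIBTPU_INIT_ARGS the user had already exported; it now appends its flag to the existing value. A minimal standalone sketch of the pattern (append_libtpu_arg is a hypothetical helper for illustration, not part of vLLM):

    import os

    def append_libtpu_arg(flag: str) -> None:
        # Preserve user-supplied args instead of overwriting them.
        existing = os.environ.get("LIBTPU_INIT_ARGS", "")
        os.environ["LIBTPU_INIT_ARGS"] = (existing + " " + flag).strip()

    os.environ["LIBTPU_INIT_ARGS"] = "--some_user_flag=1"  # hypothetical pre-set value
    append_libtpu_arg("--xla_tpu_force_1d_allreduce_at_chunk_count=1")
    print(os.environ["LIBTPU_INIT_ARGS"])
    # --some_user_flag=1 --xla_tpu_force_1d_allreduce_at_chunk_count=1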
