diff --git a/requirements/tpu.txt b/requirements/tpu.txt
index 47e638463bf5..a26dfd460d8e 100644
--- a/requirements/tpu.txt
+++ b/requirements/tpu.txt
@@ -18,9 +18,9 @@ setuptools==78.1.0
 --find-links https://storage.googleapis.com/libtpu-releases/index.html
 --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
 --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-torch==2.8.0.dev20250529
-torchvision==0.22.0.dev20250529
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250529-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250529-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250529-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch==2.8.0.dev20250605
+torchvision==0.23.0.dev20250605
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250605-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250605-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250605-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
diff --git a/tests/tpu/test_moe_pallas.py b/tests/tpu/test_moe_pallas.py
index ab6cd3069e1c..407a824d8174 100644
--- a/tests/tpu/test_moe_pallas.py
+++ b/tests/tpu/test_moe_pallas.py
@@ -27,7 +27,7 @@
 # The Pallas GMM kernel requires num_tokens * topk to be a multiple of 16
 @pytest.mark.parametrize("m", [8, 16, 64, 2048])
 @pytest.mark.parametrize("n", [128, 1024, 2048])
-@pytest.mark.parametrize("k", [128, 512, 1024])
+@pytest.mark.parametrize("k", [128, 511, 1024])
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("ep_size", EP_SIZE)
diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
index 8d2f8112d2d7..16a9f0959b5c 100644
--- a/vllm/v1/worker/tpu_worker.py
+++ b/vllm/v1/worker/tpu_worker.py
@@ -100,7 +100,8 @@ def init_device(self):
         # `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary solution to
         # fix this. It will be removed after the bug in XLA compiler is fixed.
         os.environ["LIBTPU_INIT_ARGS"] = (
-            "--xla_tpu_force_1d_allreduce_at_chunk_count=1")
+            os.environ.get("LIBTPU_INIT_ARGS", "") +
+            " --xla_tpu_force_1d_allreduce_at_chunk_count=1")
         torch.set_grad_enabled(False)
         torch.set_default_dtype(self.model_config.dtype)
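Note on the requirements/tpu.txt hunk: each torch_xla line is a PEP 508 direct URL reference guarded by a `python_version` environment marker, so pip resolves exactly one cp39/cp310/cp311 wheel for the running interpreter. A minimal sketch of how those markers evaluate, using the `packaging` library (a dependency of pip, so usually available); the loop and printout are illustration only, not part of the change:

```python
from packaging.markers import Marker

# Each torch_xla requirement carries a marker such as `python_version == "3.10"`.
# pip evaluates the marker against the current interpreter and skips
# requirements whose marker is false, so only one wheel URL is installed.
for version in ("3.9", "3.10", "3.11"):
    marker = Marker(f'python_version == "{version}"')
    # True only for the interpreter actually running this script.
    print(version, marker.evaluate())
```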
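Note on the tpu_worker.py hunk: the change replaces a plain assignment with read-then-append, so a `LIBTPU_INIT_ARGS` value exported by the user or a launcher script is no longer silently overwritten when the worker adds the XLA workaround flag. A minimal sketch of the two behaviours; `--xla_tpu_example_user_flag` is a hypothetical placeholder for whatever the user may have set:

```python
import os

# Hypothetical flag standing in for anything exported before vLLM starts.
os.environ["LIBTPU_INIT_ARGS"] = "--xla_tpu_example_user_flag=true"

# Old behaviour: plain assignment drops the user's flag.
# os.environ["LIBTPU_INIT_ARGS"] = "--xla_tpu_force_1d_allreduce_at_chunk_count=1"

# New behaviour: read the existing value (empty string if unset) and append
# the workaround flag, keeping both.
os.environ["LIBTPU_INIT_ARGS"] = (
    os.environ.get("LIBTPU_INIT_ARGS", "") +
    " --xla_tpu_force_1d_allreduce_at_chunk_count=1")

print(os.environ["LIBTPU_INIT_ARGS"])
# --xla_tpu_example_user_flag=true --xla_tpu_force_1d_allreduce_at_chunk_count=1
```

When the variable was not set beforehand, the concatenated string starts with a leading space; assuming libtpu's flag parsing tolerates that (not verified by the diff itself), the behaviour is unchanged in the default case.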