_posts/2025-10-15-vllm-tpu.md
12 additions & 12 deletions
@@ -44,8 +44,8 @@ JAX employs just-in-time (JIT) compilation to optimize Python functions for targ
For this reason, vLLM TPU now uses JAX as the lowering path for all vLLM models, benefiting from significant performance improvements, even when the model definition is written in PyTorch. This decision allows us to move faster and smarter, abstracting away higher-level frameworks to focus on kernel development and compiler optimizations. Remember: to XLA, Torchax and JAX use the same high-performance primitives ahead of compilation. You can read more about it [here](https://github.com/vllm-project/tpu-inference/blob/main/docs/developer_guides/torchax_model_development.md).
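To make the lowering path concrete, here is a minimal JAX sketch (illustrative only, not vLLM TPU code; the `tiny_mlp` function and its shapes are made up) showing how a function expressed in JAX primitives is traced by `jax.jit`, lowered to an XLA program, and then compiled and executed:

```python
# Minimal sketch: jax.jit traces the Python function into an XLA program.
# The function and shapes below are illustrative, not taken from vLLM TPU.
import jax
import jax.numpy as jnp

def tiny_mlp(x, w1, w2):
    # Two matmuls with a GELU in between -- stand-ins for a model's layers.
    return jax.nn.gelu(x @ w1) @ w2

x = jnp.ones((8, 128))
w1 = jnp.ones((128, 256))
w2 = jnp.ones((256, 128))

compiled = jax.jit(tiny_mlp)
# Inspect the IR that XLA receives, ahead of compilation.
print(compiled.lower(x, w1, w2).as_text()[:300])
out = compiled(x, w1, w2)  # compiles on first call, then reuses the executable
print(out.shape)           # (8, 128)
```

Whether the model definition is written in PyTorch (running via Torchax) or natively in JAX, this is the shape of what XLA ultimately sees before compilation.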
-| Takeaway \#1: vLLM TPU now lowers all models with JAX. Without making any changes to the model code (e.g. [llama.py](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py)), vLLM TPU now achieves ~20% higher throughput performance, simply because it now leverages JAX's mature, high-performance primitives to generate the HLO graph that is then compiled by XLA. |
-| :---- |
+> [!IMPORTANT]
+> **Takeaway #1**: vLLM TPU now lowers all models with JAX. Without making any changes to the model code (e.g. [llama.py](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py)), vLLM TPU now achieves ~20% higher throughput performance, simply because it now leverages JAX's mature, high-performance primitives to generate the HLO graph that is then compiled by XLA.
#### A Closer Look
@@ -79,8 +79,8 @@ Let’s take a closer look at what’s happening under the hood:
This unification effort reduces duplication by leveraging existing work from the vLLM community, leaving more time to optimize TPU kernels and the XLA compiler. For PyTorch (via Torchax) and JAX models, all kernels and compilers are shared.
-| Takeaway \#2: vLLM TPU will now default to running the TPU-optimized model code in tpu-inference if it exists, otherwise, it will fallback to the PyTorch model code from vLLM upstream (lowered using JAX via [Torchax](https://google.github.io/torchax/user_guide/how-it-works/)). For most users, this is an implementation detail. |
-| :---- |
+> [!IMPORTANT]
+> **Takeaway #2**: vLLM TPU will now default to running the TPU-optimized model code in tpu-inference if it exists; otherwise, it will fall back to the PyTorch model code from vLLM upstream (lowered using JAX via [Torchax](https://google.github.io/torchax/user_guide/how-it-works/)). For most users, this is an implementation detail.
*If Torchax can run PyTorch model code out-of-the-box on TPU but still compiles using JAX JIT, why did we rewrite some models in tpu-inference? Isn’t that duplicative?*
@@ -90,8 +90,8 @@ The real performance benefit and the reason why we support reimplemented models
We need this flexibility because the design choices a vLLM developer makes when implementing a model do not always favor TPU. The implementations differ not because of JAX vs. Torchax, but because GPUs and TPUs are different hardware that call for different optimization strategies.
-| Takeaway \#3: For any model, it's *all* JAX under the hood! Unless logical differences in the implementation cause TPU performance to suffer, models will likely not benefit from being rewritten natively in JAX. That said, it’s important to retain the flexibility of reimplementing models if it means we can get the best out of TPUs. |
-| :---- |
+> [!IMPORTANT]
+> **Takeaway #3**: For any model, it's *all* JAX under the hood! Unless logical differences in the implementation cause TPU performance to suffer, models will likely not benefit from being rewritten natively in JAX. That said, it’s important to retain the flexibility of reimplementing models if it means we can get the best out of TPUs.
#### Ragged Paged Attention V3: The Most Flexible and High Performance Attention Kernel for TPU Inference in OSS
@@ -109,15 +109,15 @@ Although the Ragged Paged Attention v2 kernel provided a major uptick in perform
We will be writing a technical deep dive on RPA v3 soon, so please look out for it in our docs.
-| Takeaway \#4: RPA v3 is both flexible and performant and serves as an excellent reference for production-grade Pallas kernel development in OSS. We are excited for TPU-friendly MoE and MLA kernels to land in OSS in similar fashion soon. |
-| :---- |
+> [!IMPORTANT]
+> **Takeaway #4**: RPA v3 is both flexible and performant and serves as an excellent reference for production-grade Pallas kernel development in OSS. We are excited for TPU-friendly MoE and MLA kernels to land in OSS in similar fashion soon.
#### Single Program, Multi-Data (SPMD)
This release introduces Single Program, Multi-Data ([SPMD](https://en.wikipedia.org/wiki/Single_program,_multiple_data)) as the default programming model for vLLM TPU. Unlike the previous multi-worker model (adapted from GPU paradigms), SPMD is native to the XLA compiler. Developers write code for a single, massive device, and the XLA compiler automatically partitions models and tensors, inserting communication operations for optimal execution.
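As an illustration of the SPMD model, here is a minimal JAX sketch (not code from tpu-inference; the `"model"` axis name and `layer` function are hypothetical): the code is written for a single logical device, and the compiler handles partitioning and communication across whatever devices are present.

```python
# Minimal SPMD sketch: write for one logical device, let XLA partition it.
# The axis name "model" and the layer function are illustrative only.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("model",))  # 1-D device mesh

x = jnp.ones((8, 1024))
w = jnp.ones((1024, 1024))

# Shard the weight's output dimension across the mesh; replicate the activations.
w_sharded = jax.device_put(w, NamedSharding(mesh, P(None, "model")))
x_repl = jax.device_put(x, NamedSharding(mesh, P()))

@jax.jit
def layer(x, w):
    # Written as if for one giant device; XLA partitions the matmul and
    # inserts any collectives needed to produce the sharded result.
    return x @ w

y = layer(x_repl, w_sharded)
print(y.sharding)  # output sharding is chosen by the compiler
```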
-| Takeaway \#5: SPMD enables advanced optimizations like overlapping communication with computation. SPMD represents a strategic shift towards deeper, native TPU integration, promising higher performance through a TPU-centric, compiler-first operating model. |
-| :---- |
+> [!IMPORTANT]
+> **Takeaway #5**: SPMD enables advanced optimizations like overlapping communication with computation. SPMD represents a strategic shift towards deeper, native TPU integration, promising higher performance through a TPU-centric, compiler-first operating model.
#### Bringing it All Together
@@ -134,8 +134,8 @@ This release introduces Single Program, Multi-Data ([SPMD](https://en.wikipedia.
vLLM TPU has come a very long way from the prototype's performance in February 2025, reaching **2x-5x higher performance** on those same workloads, while also improving model coverage and usability.
-| Takeaway \#6: Today, vLLM TPU is nearly 5x more performant than the first TPU prototype back in Feb 2025. With this new foundation in place, developers and researchers will now be able to push the boundaries of TPU inference performance further than ever before in open source. |
-| :---- |
+> [!IMPORTANT]
+> **Takeaway #6**: Today, vLLM TPU is nearly 5x more performant than the first TPU prototype back in Feb 2025. With this new foundation in place, developers and researchers will now be able to push the boundaries of TPU inference performance further than ever before in open source.