Commit 2fd02da

Format Takeaways with admonitions in TPU blog
Signed-off-by: Leah Karasek <karasek@google.com>
1 parent 4e71ba5 · commit 2fd02da

1 file changed (+12 −12)

_posts/2025-10-15-vllm-tpu.md

Lines changed: 12 additions & 12 deletions
@@ -44,8 +44,8 @@ JAX employs just-in-time (JIT) compilation to optimize Python functions for targ

For this reason, vLLM TPU now uses JAX as the lowering path for all vLLM models, benefiting from significant performance improvements, even when the model definition is written in PyTorch. This decision allows us to move faster and smarter, abstracting away higher level frameworks to focus on kernel development and compiler optimizations. Remember, to XLA, Torchax and JAX use the same high performance primitives ahead of compilation. You can read more about it [here](https://github.com/vllm-project/tpu-inference/blob/main/docs/developer_guides/torchax_model_development.md).

-| Takeaway \#1: vLLM TPU now lowers all models with JAX. Without making any changes to the model code (e.g. [llama.py](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py)), vLLM TPU now achieves ~20% higher throughput performance, simply because it now leverages JAX's mature, high-performance primitives to generate the HLO graph that is then compiled by XLA. |
-| :---- |
+> [!IMPORTANT]
+> **Takeaway #1**: vLLM TPU now lowers all models with JAX. Without making any changes to the model code (e.g. [llama.py](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py)), vLLM TPU now achieves ~20% higher throughput performance, simply because it now leverages JAX's mature, high-performance primitives to generate the HLO graph that is then compiled by XLA.
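
To make the lowering path above concrete, here is a minimal, self-contained JAX sketch (illustrative only; the toy `mlp` function stands in for a real model forward pass and is not code from vLLM TPU or tpu-inference). `jax.jit` traces the Python function into JAX primitives, `.lower()` emits the HLO/StableHLO module, and XLA compiles that module for the target device:

```python
import jax
import jax.numpy as jnp

def mlp(x, w1, w2):
    # Toy two-layer block standing in for a model's forward pass.
    return jnp.dot(jax.nn.gelu(jnp.dot(x, w1)), w2)

x = jnp.ones((8, 128))
w1 = jnp.ones((128, 512))
w2 = jnp.ones((512, 128))

# Trace the function into JAX primitives and lower it to an HLO module.
lowered = jax.jit(mlp).lower(x, w1, w2)
print(lowered.as_text()[:400])    # inspect the generated HLO/StableHLO text

# XLA compiles the lowered module into an executable for the target device.
compiled = lowered.compile()
print(compiled(x, w1, w2).shape)  # (8, 128)
```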

#### A Closer Look


@@ -79,8 +79,8 @@ Let’s take a closer look at what’s happening under the hood:

This unification effort reduces duplication by leveraging existing work from the vLLM community, leaving more time to optimize TPU kernels and the XLA compiler. For PyTorch (via Torchax) and JAX models, all kernels and compilers are shared.

-| Takeaway \#2: vLLM TPU will now default to running the TPU-optimized model code in tpu-inference if it exists, otherwise, it will fallback to the PyTorch model code from vLLM upstream (lowered using JAX via [Torchax](https://google.github.io/torchax/user_guide/how-it-works/)). For most users, this is an implementation detail. |
-| :---- |
+> [!IMPORTANT]
+> **Takeaway #2**: vLLM TPU will now default to running the TPU-optimized model code in tpu-inference if it exists; otherwise, it will fall back to the PyTorch model code from vLLM upstream (lowered using JAX via [Torchax](https://google.github.io/torchax/user_guide/how-it-works/)). For most users, this is an implementation detail.
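
As a purely hypothetical sketch of the resolution order described in Takeaway #2 (the names below are illustrative stand-ins, not actual tpu-inference or vLLM APIs):

```python
# Hypothetical sketch only: these names are illustrative, not real APIs.
NATIVE_TPU_MODELS = {"LlamaForCausalLM"}  # architectures reimplemented in tpu-inference

def load_native_jax_model(arch: str) -> str:
    # Stand-in for loading the TPU-optimized JAX implementation from tpu-inference.
    return f"{arch}: native JAX implementation (tpu-inference)"

def load_upstream_via_torchax(arch: str) -> str:
    # Stand-in for loading the upstream vLLM PyTorch definition, lowered via Torchax.
    return f"{arch}: upstream PyTorch model, lowered to JAX via Torchax"

def resolve_model(arch: str) -> str:
    # Prefer the TPU-optimized definition when one exists; otherwise fall back
    # to the upstream PyTorch definition. Either way, XLA compiles JAX primitives.
    if arch in NATIVE_TPU_MODELS:
        return load_native_jax_model(arch)
    return load_upstream_via_torchax(arch)

print(resolve_model("LlamaForCausalLM"))
print(resolve_model("SomeOtherForCausalLM"))
```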

*If Torchax can run PyTorch model code out-of-the-box on TPU but still compiles using JAX JIT, why did we rewrite some models in tpu-inference? Isn’t that duplicative?*


@@ -90,8 +90,8 @@ The real performance benefit and the reason why we support reimplemented models

The reason we need this flexibility is that the design choices a vLLM developer makes when implementing a model do not always favor TPU. This makes them different, not because of JAX vs Torchax, but because GPUs are different from TPUs, requiring different optimization strategies.

-| Takeaway \#3: For any model, it's *all* JAX under the hood! Unless logical differences in the implementation cause TPU performance to suffer, models will likely not benefit from being rewritten natively in JAX. That said, it’s important to retain the flexibility of reimplementing models if it means we can get the best out of TPUs. |
-| :---- |
+> [!IMPORTANT]
+> **Takeaway #3**: For any model, it's *all* JAX under the hood! Unless logical differences in the implementation cause TPU performance to suffer, models will likely not benefit from being rewritten natively in JAX. That said, it’s important to retain the flexibility of reimplementing models if it means we can get the best out of TPUs.

#### Ragged Paged Attention V3: The Most Flexible and High Performance Attention Kernel for TPU Inference in OSS


@@ -109,15 +109,15 @@ Although the Ragged Paged Attention v2 kernel provided a major uptick in perform

We will be writing a technical deep dive on RPA v3 soon, so please look out for it in our docs.

-| Takeaway \#4: RPA v3 is both flexible and performant and serves as an excellent reference for production-grade Pallas kernel development in OSS. We are excited for TPU-friendly MoE and MLA kernels to land in OSS in similar fashion soon. |
-| :---- |
+> [!IMPORTANT]
+> **Takeaway #4**: RPA v3 is both flexible and performant and serves as an excellent reference for production-grade Pallas kernel development in OSS. We are excited for TPU-friendly MoE and MLA kernels to land in OSS in similar fashion soon.

#### Single Program, Multi-Data (SPMD)

This release introduces Single Program, Multi-Data ([SPMD](https://en.wikipedia.org/wiki/Single_program,_multiple_data)) as the default programming model for vLLM TPU. Unlike the previous multi-worker model (adapted from GPU paradigms), SPMD is native to the XLA compiler. Developers write code for a single, massive device, and the XLA compiler automatically partitions models and tensors, inserting communication operations for optimal execution.

-| Takeaway \#5: SPMD enables advanced optimizations like overlapping communication with computation. SPMD represents a strategic shift towards deeper, native TPU integration, promising higher performance through a TPU-centric, compiler-first operating model. |
-| :---- |
+> [!IMPORTANT]
+> **Takeaway #5**: SPMD enables advanced optimizations like overlapping communication with computation. SPMD represents a strategic shift towards deeper, native TPU integration, promising higher performance through a TPU-centric, compiler-first operating model.
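
To illustrate the SPMD programming model described above, here is a minimal, hypothetical JAX sketch (a toy example, not code from vLLM TPU): the function is written for one logical device, while the XLA SPMD partitioner splits the sharded matmul across the mesh and inserts any collectives it needs.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange all available devices (TPU chips in practice) into a 1D mesh.
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

x = jnp.ones((8, 1024))
w = jnp.ones((1024, 4096))

# Replicate activations; shard the weight's output dimension over the "model" axis.
x = jax.device_put(x, NamedSharding(mesh, P()))
w = jax.device_put(w, NamedSharding(mesh, P(None, "model")))

@jax.jit
def layer(x, w):
    # Written as if for a single giant device; partitioning and communication
    # insertion are handled by the XLA compiler.
    return jnp.dot(x, w)

out = layer(x, w)
print(out.shape, out.sharding)
```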

#### Bringing it All Together


@@ -134,8 +134,8 @@ This release introduces Single Program, Multi-Data ([SPMD](https://en.wikipedia.

vLLM TPU has come a very long way from the prototype performance in February 2025, reaching nearly **2x-5x performance** on those same workloads, while also improving model coverage and usability.

-| Takeaway \#6: Today, vLLM TPU is nearly 5x more performant than the first TPU prototype back in Feb 2025. With this new foundation in place, developers and researchers will now be able to push the boundaries of TPU inference performance further than ever before in open source. |
-| :---- |
+> [!IMPORTANT]
+> **Takeaway #6**: Today, vLLM TPU is nearly 5x more performant than the first TPU prototype back in Feb 2025. With this new foundation in place, developers and researchers will now be able to push the boundaries of TPU inference performance further than ever before in open source.

### Models, Features, and What’s Next

