
Commit 2fe4beb

Update images and some wording
Signed-off-by: Leah Karasek <karasek@google.com>
1 parent b4bcc0a commit 2fe4beb

5 files changed: 3 additions & 1 deletion


_posts/2025-10-15-vllm-tpu.md

Lines changed: 3 additions & 1 deletion
@@ -40,10 +40,12 @@ Although vLLM TPU with PTXLA was a major accomplishment, we needed to continue t
This new vLLM TPU redesign with [tpu-inference](http://tpu.vllm.ai) aims to optimize performance and extensibility by supporting PyTorch (via [Torchax](https://google.github.io/torchax/)) and [JAX](https://docs.jax.dev/en/latest/index.html) within a single unified JAX→XLA lowering path. Let’s dive into what this means:

- JAX employs just-in-time (JIT) compilation to optimize Python functions for target hardware, such as TPUs. This process involves tracing a function to record the sequence of JAX primitives, which are then used to generate an XLA (Accelerated Linear Algebra) High-Level Optimizer (HLO) graph. The XLA compiler subsequently applies automatic hardware-specific optimizations. Compared to PyTorch/XLA, JAX is a more mature stack, generally offering superior coverage and performance for its primitives, particularly when implementing complex parallelism strategies.
+ Compared to PyTorch/XLA, JAX is a more mature stack, generally offering superior coverage and performance for its primitives, particularly when implementing complex parallelism strategies.
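As a minimal, illustrative sketch of that lowering flow (not part of the post's diff), the snippet below traces a toy function with JAX, prints the recorded primitives (the jaxpr), and then lowers and compiles it through XLA. The `rms_norm` function and its shapes are invented for this example; `jax.make_jaxpr`, `jax.jit`, `.lower()`, `.as_text()`, and `.compile()` are standard JAX APIs.

```python
import jax
import jax.numpy as jnp

# Toy function standing in for a model layer (hypothetical example).
def rms_norm(x, w, eps=1e-6):
    # Root-mean-square normalization followed by a learned scale.
    rms = jnp.sqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps)
    return (x / rms) * w

x = jnp.ones((8, 1024), dtype=jnp.bfloat16)
w = jnp.ones((1024,), dtype=jnp.bfloat16)

# 1. Tracing records the sequence of JAX primitives (the jaxpr).
print(jax.make_jaxpr(rms_norm)(x, w))

# 2. Lowering emits the HLO/StableHLO text that XLA consumes.
lowered = jax.jit(rms_norm).lower(x, w)
print(lowered.as_text()[:400])

# 3. XLA compiles that graph with hardware-specific optimizations
#    (targeting the TPU backend when run on a TPU host).
compiled = lowered.compile()
print(compiled(x, w).shape)
```

These are the same trace → HLO → XLA stages described in the deleted sentence above; what changes between native JAX and Torchax model code is only how the traced primitives are produced.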

For this reason, vLLM TPU now uses JAX as the lowering path for all vLLM models, benefiting from significant performance improvements even when the model definition is written in PyTorch. This decision allows us to move faster and smarter, abstracting away higher-level frameworks so we can focus on kernel development and compiler optimizations. Remember: to XLA, Torchax and JAX use the same high-performance primitives ahead of compilation. You can read more about it [here](https://github.com/vllm-project/tpu-inference/blob/main/docs/developer_guides/torchax_model_development.md).
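To sketch the Torchax side of that same path, the example below runs a PyTorch-defined module on the JAX backend. It follows the usage pattern shown in the Torchax README (`torchax.enable_globally()` plus the `'jax'` device string); treat the exact enablement API and device handling as assumptions, and refer to the linked developer guide for the form tpu-inference actually uses. `TinyMLP` is a made-up module for illustration.

```python
import torch
import torch.nn as nn

import torchax  # PyTorch-on-JAX interop layer referenced above

# Assumption: global enablement as shown in the Torchax README;
# the exact API surface may differ between releases.
torchax.enable_globally()

# An ordinary PyTorch model definition -- no JAX code in the module itself.
class TinyMLP(nn.Module):
    def __init__(self, d_in=256, d_hidden=512, d_out=256):
        super().__init__()
        self.up = nn.Linear(d_in, d_hidden)
        self.down = nn.Linear(d_hidden, d_out)

    def forward(self, x):
        return self.down(torch.relu(self.up(x)))

# Moving the module and inputs to the 'jax' device routes each torch op
# through JAX primitives, so XLA compiles the same kind of HLO it would
# see from native JAX model code.
model = TinyMLP().to('jax')
x = torch.randn(4, 256).to('jax')
print(model(x).shape)
```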

+ Although this is our current design, we will always strive to achieve the best possible performance on TPU, and we plan to evaluate a native PyTorch port for vLLM TPU in the future.
+
> [!IMPORTANT]
> **Takeaway #1**: vLLM TPU now lowers all models with JAX. Without making any changes to the model code (e.g. [llama.py](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py)), vLLM TPU now achieves ~20% higher throughput, simply because it leverages JAX's mature, high-performance primitives to generate the HLO graph that is then compiled by XLA.
4 image files changed (image previews not shown).
