For this reason, vLLM TPU now uses JAX as the lowering path for all vLLM models, benefiting from significant performance improvements even when the model definition is written in PyTorch. This decision lets us move faster and smarter, abstracting away higher-level frameworks to focus on kernel development and compiler optimizations. Remember: from XLA's perspective, Torchax and JAX emit the same high-performance primitives ahead of compilation. You can read more about it [here](https://github.com/vllm-project/tpu-inference/blob/main/docs/developer_guides/torchax_model_development.md).
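To make that last point concrete, here is a minimal sketch in plain JAX (not vLLM or Torchax code, and the toy function is purely illustrative) showing how to inspect what XLA actually receives: the StableHLO that a traced function lowers to ahead of compilation. A PyTorch model bridged through Torchax bottoms out in the same kind of lowered representation.

```python
import jax
import jax.numpy as jnp

def attention_score(q, k):
    # A toy matmul + softmax, standing in for a model component.
    return jax.nn.softmax(q @ k.T / jnp.sqrt(q.shape[-1]))

q = jnp.ones((4, 64))
k = jnp.ones((4, 64))

# Lower the jitted function and print the StableHLO handed to XLA,
# before any compilation or TPU-specific optimization happens.
print(jax.jit(attention_score).lower(q, k).as_text())
```

Regardless of the framework the model was authored in, it is this lowered program, not the Python source, that XLA compiles and optimizes.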