Motivation
For many years, torch_xla has been the primary way for the community to run PyTorch programs on Cloud TPUs. It has successfully enabled the training of massive models by bringing the power of the XLA compiler to the PyTorch ecosystem.
The current implementation, while powerful, presents a developer experience that can feel distinct from "native" PyTorch. The reliance on a lazy tensor model and explicit graph tracing (xm.mark_step) creates a separation from PyTorch's eager-first philosophy. This can make debugging harder, complicate integration with the broader PyTorch ecosystem, and require users to learn a torch_xla-specific set of APIs and concepts.
We believe we can deliver a more seamless and native experience for PyTorch users on TPUs. The goal is to provide the best of both worlds: the interactive, flexible development experience of PyTorch's eager mode and the world-class performance of the XLA compiler for scaled-out workloads.
Proposal: A Native TPU Backend
We propose a TPU backend for PyTorch that is designed to align with modern PyTorch architecture and eager-first design. The goal is to make a "native" device in PyTorch, where tensor.to('tpu') feels just as natural and intuitive as tensor.to('cuda'). This new direction aims to fully embrace PyTorch's eager mode while still leveraging the powerful XLA compiler for performance-critical code paths.
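For illustration, here is a minimal sketch of the intended developer experience. It assumes the proposed 'tpu' device string, which is not an existing API today:

```python
import torch

# Allocate directly on the TPU, or move an existing tensor, exactly as with 'cuda'
# (the 'tpu' device string is an assumption of this proposal).
x = torch.randn(4, 4, device="tpu")
y = torch.randn(4, 4).to("tpu")

# Ops dispatch eagerly from the user's point of view; compilation and execution
# would happen behind the scenes on the proposed backend.
z = x @ y
print(z)  # tensors can be inspected with standard Python tooling
```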
The core principles of this new stack are:
- XLA: As with torch_xla, our proposal assumes that we can continue to rely on XLA as the underlying compiler infrastructure. However, we would call it in a fundamentally different way that enables new techniques and a better user experience. Note that on TPU, compilation is required for the best performance, but it should be possible to hide the compile times.
- Eager Mode with Deferred Execution: As in standard PyTorch eager mode, ops are dispatched as they are issued. However, the new stack can then choose to compile and execute individual ops, shorter or longer sequences of ops, or potential candidates for fusion clusters, all the way up to a full compile of a forward or backward pass. Compilation would happen asynchronously, which means compilation of graphs and their execution could overlap, and compilation results would be cached. We would work with the XLA team to further reduce overall compile-time overhead with techniques such as persistent deduping and by limiting inlining and unrolling. As a result, the compile-time overhead would be drastically minimized even for larger incrementally compiled graphs.
- JIT: This approach would enable a true just-in-time compilation engine with recompilation, feedback-directed optimizations, autotuning, and active memory management to avoid OOMs. With this, users would get the eager experience but with compiled performance after just a few inferences or training steps.
With these principles in mind, we could deliver on the following features:
- Eager Execution by Default: As described above, operations will appear to execute eagerly, just as they do on CPU or GPU, even though they are compiled in the background with minimal and mostly hidden compile-time overhead. This would provide a familiar, intuitive, and much easier-to-debug workflow where users can inspect tensors and use standard Python tooling.
- Integration with torch.compile: To maximize performance, TPU would integrate as a first-class backend for torch.compile. This would allow users to get the performance benefits of XLA compilation and TPUs at scale on their performance-critical code with a simple @torch.compile decorator (see the sketch after this list).
- Distributed Training via DTensor: The new backend would natively support PyTorch's distributed APIs. This would allow users to leverage advanced, large-scale distributed training strategies like Fully Sharded Data Parallel (FSDP) and other model-parallelism techniques out of the box, making it much simpler to scale up models.
- A More "PyTorch Native" Feel: The end goal is to abstract away the complexities of the underlying compiler. Developing for a TPU should not require a fundamentally different programming model. This would mean moving away from torch_xla-specific APIs and toward the standard PyTorch API surface. This approach would provide the best of both worlds: the interactive, flexible development experience of PyTorch's eager mode and the world-class performance of the XLA compiler for scaled-out workloads.
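To ground these features, here is a hedged sketch of what the combined workflow could look like. The 'tpu' device string, the behavior of torch.compile on this backend, and the FSDP wiring are assumptions of the proposal, not an existing API:

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Standard eager-style model definition, placed on the proposed 'tpu' device
# (assumed here; not an existing device string today).
model = nn.Linear(1024, 1024).to("tpu")

# Performance-critical path marked with torch.compile; the proposed TPU
# backend would hand the captured graph to the XLA compiler.
@torch.compile
def train_step(x, y):
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss

# Scaling out with PyTorch's standard distributed APIs (inside an already
# initialized process group), with no torch_xla-specific wrappers.
sharded_model = FSDP(model)
```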
We Want Your Feedback!
We're excited about this direction and about bringing together PyTorch's eager mode and the XLA compiler in a way that helps the community achieve new levels of performance and scale. This is a significant undertaking, and we want to build it with the community. We're open to feedback on this direction.
- Does this proposal address the pain points you've experienced with torch_xla?
- Are there specific workflows or PyTorch features whose support is critical for your work?
- What would be the most important factors for you when considering a migration to this new stack, whether from torch_xla or from PyTorch on GPU?
Thank you for being a part of the PyTorch/XLA community. We're excited to build this next chapter with you.