Note: PyTorch/XLA r2.1 will be the last release with XRT available as a legacy runtime. Our main release build will not include XRT, but it will be available in a separate package. See our main README for all available builds, including GPU builds.
PyTorch/XLA is a Python package that uses the XLA deep learning compiler to connect the PyTorch deep learning framework and Cloud TPUs. You can try it right now, for free, on a single Cloud TPU VM with Kaggle!
Take a look at one of our Kaggle notebooks to get started:
To install PyTorch/XLA on a new VM:
pip install torch~=2.1.0 torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
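After installing, it can be useful to confirm that both packages are present and that their versions line up, since the command above pins `torch` and `torch_xla` to the same 2.1 series. The following is a minimal sketch using only the standard library; the helper names `installed_version`, `same_minor`, and `check_install` are hypothetical and not part of PyTorch/XLA.

```python
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def same_minor(version_a, version_b):
    """True when both versions share the same major.minor prefix, e.g. 2.1."""
    return version_a.split(".")[:2] == version_b.split(".")[:2]


def check_install():
    """Collect the installed versions of torch and torch_xla (hypothetical helper)."""
    versions = {name: installed_version(name) for name in ("torch", "torch_xla")}
    for name, version in versions.items():
        print(f"{name}: {version or 'not installed'}")
    if all(versions.values()) and not same_minor(*versions.values()):
        print("warning: torch and torch_xla are from different release series")
    return versions


report = check_install()
```

This only inspects package metadata; it does not touch a TPU, so it says nothing about whether the device runtime itself is working.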
To update your existing training loop, make the following changes:
-import torch.multiprocessing as mp
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.parallel_loader as pl
+import torch_xla.distributed.xla_multiprocessing as xmp
 def _mp_fn(index):
   ...
+  # Move the model parameters to your XLA device
+  model.to(xm.xla_device())
+
+  # MpDeviceLoader preloads data to the XLA device
+  xla_train_loader = pl.MpDeviceLoader(train_loader, xm.xla_device())
-  for inputs, labels in train_loader:
+  for inputs, labels in xla_train_loader:
     optimizer.zero_grad()
     outputs = model(inputs)
     loss = loss_fn(outputs, labels)
     loss.backward()
-    optimizer.step()
+
+    # `xm.optimizer_step` combines gradients across replicas
+    xm.optimizer_step(optimizer)
 if __name__ == '__main__':
-  mp.spawn(_mp_fn, args=(), nprocs=world_size)
+  # xmp.spawn automatically selects the correct world size
+  xmp.spawn(_mp_fn, args=())
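To make "combines gradients across replicas" concrete, here is a framework-free sketch of the idea. It assumes the combination is an element-wise average followed by a plain SGD update; the functions `all_reduce_mean` and `sgd_step` and all the numbers are hypothetical illustrations, not the PyTorch/XLA implementation.

```python
def all_reduce_mean(per_replica_grads):
    """Average gradients element-wise across replicas (assumed reduction)."""
    num_replicas = len(per_replica_grads)
    return [sum(vals) / num_replicas for vals in zip(*per_replica_grads)]


def sgd_step(params, grads, lr):
    """Apply one plain SGD update and return the new parameters."""
    return [p - lr * g for p, g in zip(params, grads)]


# Two hypothetical replicas hold identical parameters but saw different
# batches, so their local gradients differ.
params = [1.0, -2.0]
local_grads = [
    [0.5, 1.0],  # replica 0
    [1.5, 3.0],  # replica 1
]

combined = all_reduce_mean(local_grads)

# Every replica applies the same combined gradient, so all copies of the
# model remain identical after the step.
replica_params = [sgd_step(params, combined, lr=0.5) for _ in local_grads]

print(combined)
print(replica_params)
```

Because each replica steps with the same combined gradient, the model copies stay in sync without any separate parameter broadcast after every step.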
If you're using DistributedDataParallel, make the following changes:
 import torch.distributed as dist
-import torch.multiprocessing as mp
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.parallel_loader as pl
+import torch_xla.distributed.xla_multiprocessing as xmp
+import torch_xla.distributed.xla_backend
 def _mp_fn(rank, world_size):
   ...
-  os.environ['MASTER_ADDR'] = 'localhost'
-  os.environ['MASTER_PORT'] = '12355'
-  dist.init_process_group("gloo", rank=rank, world_size=world_size)
+  # Rank and world size are inferred from the XLA device runtime
+  dist.init_process_group("xla", init_method='xla://')
+
+  model.to(xm.xla_device())
+  # `gradient_as_bucket_view=True` required for XLA
+  ddp_model = DDP(model, gradient_as_bucket_view=True)
-  model = model.to(rank)
-  ddp_model = DDP(model, device_ids=[rank])
+  xla_train_loader = pl.MpDeviceLoader(train_loader, xm.xla_device())
-  for inputs, labels in train_loader:
+  for inputs, labels in xla_train_loader:
     optimizer.zero_grad()
     outputs = ddp_model(inputs)
     loss = loss_fn(outputs, labels)
     loss.backward()
     optimizer.step()
 if __name__ == '__main__':
-  mp.spawn(_mp_fn, args=(), nprocs=world_size)
+  xmp.spawn(_mp_fn, args=())
Additional information on PyTorch/XLA, including a description of its semantics and functions, is available at PyTorch.org. See the API Guide for best practices when writing networks that run on XLA devices (TPU, GPU, CPU and...).
Our comprehensive user guides are available at:
Documentation for the latest release
Documentation for master branch
For all builds and all versions of torch-xla, see our main GitHub README.
If PyTorch/XLA isn't performing as expected, see the troubleshooting guide, which has suggestions for debugging and optimizing your network(s).
The PyTorch/XLA team is always happy to hear from users and OSS contributors! The best way to reach out is by filing an issue in this GitHub repository. Questions, bug reports, feature requests, build issues, etc. are all welcome!
See the contribution guide.
This repository is jointly operated and maintained by Google, Facebook and a number of individual contributors listed in the CONTRIBUTORS file. For questions directed at Facebook, please send an email to opensource@fb.com. For questions directed at Google, please send an email to pytorch-xla@googlegroups.com. For all other questions, please open up an issue in this repository here.
You can find additional useful reading materials in