PyTorch / XLA adds a new device, similar to the CPU and GPU devices. The following snippet creates an XLA tensor filled with random values, then prints the device and the contents of the tensor:
```python
import torch
import torch_xla
import torch_xla_py.xla_model as xm

x = torch.randn(4, 2, device=xm.xla_device())
print(x.device)
print(x)
```
The XLA device is not a physical device but instead stands in for either a Cloud TPU or CPU.
The XLA readme describes all the options available to run on TPU or CPU.
To run a model, use the following API:
```python
import torch.nn as nn
import torch.optim as optim
import torch_xla_py.xla_model as xm
import torch_xla_py.data_parallel as dp

# MNIST (the model class), lr, momentum, num_epochs and train_loader are
# defined elsewhere, e.g. in the full MNIST example.
devices = xm.get_xla_supported_devices()
model_parallel = dp.DataParallel(MNIST, device_ids=devices)


def train_loop_fn(model, loader, device, context):
    loss_fn = nn.NLLLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    model.train()
    for _, (data, target) in loader:
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        # Replaces optimizer.step(); also acts as the per-step execution barrier.
        xm.optimizer_step(optimizer)


for epoch in range(1, num_epochs + 1):
    model_parallel(train_loop_fn, train_loader)
```
The same multi-core API can be used to run on a single core as well by setting the `device_ids` argument to the selected core. Passing `[]` as `device_ids` causes the model to run using the PyTorch native CPU support.
Note the `xm.optimizer_step(optimizer)` line, which replaces the usual `optimizer.step()`. This is required because of the way XLA tensors work: operations are not executed immediately, but are instead added to a graph of pending operations that is only executed when its results are required. `xm.optimizer_step(optimizer)` acts as an execution barrier that forces the evaluation of the graph accumulated during a single step. Without this barrier, the graph in this example would only be evaluated when the model's accuracy is computed, which happens only at the end of an epoch. Even for small models, the graph accumulated over an entire epoch would be too large to evaluate.
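The execution-barrier argument above can be illustrated with a small pure-Python toy. It is only an analogy and not how PyTorch / XLA is implemented: `PendingGraph`, `LazyValue`, and `train` are hypothetical names, `barrier()` plays the role of `xm.optimizer_step(optimizer)`, and the update rule `w = w * 0.5 + 0.1` stands in for a training step.

```python
# Toy model of lazy evaluation; this is NOT how PyTorch / XLA is implemented.


class LazyValue:
    """A node whose computation is deferred until the graph hits a barrier."""

    def __init__(self, graph, fn, inputs):
        self.graph = graph
        self.fn = fn            # computes this node's value from its inputs
        self.inputs = inputs    # upstream LazyValue nodes or plain Python numbers
        self.result = None      # filled in only when a barrier executes the node

    def __add__(self, other):
        return self.graph.record(lambda a, b: a + b, [self, other])

    def __mul__(self, other):
        return self.graph.record(lambda a, b: a * b, [self, other])


class PendingGraph:
    """Records operations and runs them only when barrier() is called."""

    def __init__(self):
        self.pending = []       # nodes recorded since the last barrier

    def constant(self, value):
        return self.record(lambda: value, [])

    def record(self, fn, inputs):
        node = LazyValue(self, fn, inputs)
        self.pending.append(node)
        return node

    def barrier(self):
        """Execute every pending node and return how many there were."""
        for node in self.pending:   # creation order is a valid dependency order
            args = [i.result if isinstance(i, LazyValue) else i for i in node.inputs]
            node.result = node.fn(*args)
        size = len(self.pending)
        self.pending = []
        return size


def train(steps, barrier_every_step):
    """Run a fake update rule and report the largest graph ever executed."""
    graph = PendingGraph()
    w = graph.constant(1.0)
    largest = 0
    for _ in range(steps):
        w = w * 0.5 + 0.1       # recorded into the graph, not computed yet
        if barrier_every_step:  # analogous to xm.optimizer_step(optimizer)
            largest = max(largest, graph.barrier())
    if not barrier_every_step:  # single evaluation, e.g. at the end of an epoch
        largest = graph.barrier()
    return w.result, largest
```

With `train(100, True)` each executed graph holds at most 3 nodes, while `train(100, False)` executes one 201-node graph at the end; both produce the same value. A real training step records far more than two operations, so the end-of-epoch graph grows much faster than in this toy.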
Check the full example showing how to train MNIST on TPU.
PyTorch / XLA behaves semantically like regular PyTorch, and XLA tensors implement the full tensor interface. However, constraints in XLA and the hardware, together with the lazy evaluation model, mean that some patterns must be avoided:
- Tensor shapes should be the same between iterations, or only a small number of shape variations should be used. PyTorch / XLA automatically recompiles the graph every time it encounters new shapes, so if the shapes don’t stabilize during training, more time will be spent compiling than running the model. Pad tensors to fixed sizes when possible. Direct or indirect uses of `nonzero` introduce dynamic shapes; one example is masked indexing `base[index]`, where `index` is a mask tensor.
- Certain operations don’t have native translations to XLA and therefore require a transfer to CPU memory, evaluation on the CPU, and a transfer of the result back to the XLA device. PyTorch / XLA handles this automatically, but performing too many such operations during the training step can lead to significant slowdowns. The `item()` operation is one such example, and it is used in `clip_grad_norm_`. Below is an alternative implementation that avoids the need for `item()`:

  ```python
  ...
  else:
      device = parameters[0].device
      # Accumulate the norm on the device instead of calling item().
      total_norm = torch.zeros([], device=device if parameters else None)
      for p in parameters:
          param_norm = p.grad.data.norm(norm_type) ** norm_type
          total_norm.add_(param_norm)
      total_norm = (total_norm ** (1. / norm_type))
  clip_coef = torch.tensor(max_norm, device=device) / (total_norm + 1e-6)
  for p in parameters:
      # torch.where keeps the clipping decision on the device; a Python
      # `if clip_coef < 1` would force the value back to the host.
      p.grad.data.mul_(torch.where(clip_coef < 1, clip_coef, torch.tensor(1., device=device)))
  ```
- Loops with a different number of iterations between steps are subject to the same considerations as tensor shapes. PyTorch / XLA handles them automatically, but each different iteration count produces a different execution graph that requires recompilation.
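The recompilation cost described in the list above can be illustrated with a toy cache that counts one "compilation" per distinct input shape. This is a sketch under stated assumptions, not PyTorch / XLA internals: `ShapeKeyedCache`, `pad_to_bucket`, and the sizes in `BUCKETS` are hypothetical names and values chosen for illustration, the sequence lengths are made up, and lists of lists stand in for tensors.

```python
# Toy sketch (not PyTorch / XLA internals): count "compilations" per distinct
# input shape, for raw variable-length batches vs. batches padded to buckets.

BUCKETS = (16, 32, 64, 128)  # hypothetical fixed lengths to pad to


def pad_to_bucket(seq, buckets=BUCKETS, pad_value=0):
    """Pad a list to the smallest bucket length that fits it."""
    for size in buckets:
        if len(seq) <= size:
            return seq + [pad_value] * (size - len(seq))
    raise ValueError(f"length {len(seq)} exceeds largest bucket {buckets[-1]}")


class ShapeKeyedCache:
    """Counts compilations: every new shape means a new graph to compile."""

    def __init__(self):
        self.compiled = set()
        self.compilations = 0

    def run(self, batch):
        shape = (len(batch), len(batch[0]))
        if shape not in self.compiled:
            self.compiled.add(shape)
            self.compilations += 1  # stands in for an expensive XLA compile
        return shape


lengths = [5, 9, 17, 23, 31, 40, 63, 70, 99, 120]  # made-up sequence lengths
raw, padded = ShapeKeyedCache(), ShapeKeyedCache()
for n in lengths:
    seq = [1] * n
    raw.run([seq])                  # one distinct shape per distinct length
    padded.run([pad_to_bucket(seq)])  # at most one shape per bucket
print(raw.compilations, padded.compilations)  # 10 4
```

Padding the ten lengths into four buckets cuts the number of distinct shapes, and therefore compilations, from 10 to 4. A real model would also carry a mask so that padded positions don't contribute to the loss. Keying the cache on a loop's iteration count instead of a shape would show the same effect for loops: one compilation per distinct count.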
`print(torch_xla._XLAC._xla_metrics_report())` can be used at the end of each step to print metrics, including the number of compilations and the operators that are part of the model but don’t have native XLA implementations.