Commit

Allow MpDeviceLoader to shard dictionaries of tensor (#8202)
bhavya01 authored Oct 2, 2024
1 parent b9e65e7 commit cc631b9
Showing 5 changed files with 236 additions and 16 deletions.
20 changes: 19 additions & 1 deletion docs/spmd_advanced.md
@@ -10,10 +10,28 @@ PyTorch/XLA SPMD takes a single-device program, shards and executes it in parallel
train_loader = pl.MpDeviceLoader(
    train_loader,  # wraps PyTorch DataLoader
    device,
    # assume 4d input and we want to shard at the batch dimension.
    input_sharding=xs.ShardingSpec(input_mesh, ('data', None, None, None)))
```

It is also possible to specify a different `input_sharding` for each element of the batch, for example when the elements have different shapes:

```python
# if batch = next(train_loader) looks like
# {'x': <tensor of shape [s1, s2, s3, s4]>, 'y': <tensor of shape [s1, s2]>}

# MpDeviceLoader returns ParallelLoader.per_device_loader as iterator
train_loader = pl.MpDeviceLoader(
    train_loader,  # wraps PyTorch DataLoader
    device,
    # specify different sharding for each input of the batch.
    input_sharding={
        'x': xs.ShardingSpec(input_mesh, ('data', None, None, None)),
        'y': xs.ShardingSpec(input_mesh, ('data', None))
    })
```
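
The loader preserves the structure of each batch, so every element it yields is a dict whose values are already on the device with the requested sharding. A minimal sketch for inspecting the result, mirroring the checks in `test/spmd/test_mp_input_sharding.py` (it relies on the private helper `torch_xla._XLAC._get_xla_sharding_spec`, so treat it as a debugging aid rather than a public API):

```python
import torch_xla

batch = next(iter(train_loader))
# Each value keeps its key and carries the sharding requested above.
print(torch_xla._XLAC._get_xla_sharding_spec(batch['x']))  # e.g. '{devices=[4,1,1,1]0,1,2,3}'
print(torch_xla._XLAC._get_xla_sharding_spec(batch['y']))  # e.g. '{devices=[4,1]0,1,2,3}'
```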

### Virtual Device Optimization

PyTorch/XLA normally transfers tensor data asynchronously from host to device once the tensor is defined, so that the transfer overlaps with graph tracing. However, because GSPMD allows the user to modify the tensor sharding _after_ the tensor has been defined, we need an optimization to prevent unnecessary transfers of tensor data back and forth between host and devices. We introduce Virtual Device Optimization, a technique that first places the tensor data on a virtual device, SPMD:0, and only uploads it to the physical devices once all the sharding decisions are finalized. Every tensor in SPMD mode is placed on this virtual device, which is exposed to the user as the XLA device XLA:0, with the actual shards living on physical devices such as TPU:0, TPU:1, etc.
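
A small sketch of this behavior, using the mesh helpers from `torch_xla.distributed.spmd` (the shapes and mesh axis name are illustrative):

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
from torch_xla import runtime as xr

xr.use_spmd()

t = torch.randn(16, 128).to(xm.xla_device())
print(t.device)  # xla:0 -- the single virtual SPMD device

# The sharding can still be decided after the tensor exists; the data only
# lands on the physical devices (TPU:0, TPU:1, ...) once it is finalized.
mesh = xs.get_1d_mesh('data')
xs.mark_sharding(t, mesh, ('data', None))
```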
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -245,6 +245,7 @@ function run_xla_op_tests3 {
run_test "$CDIR/spmd/test_dtensor_integration2.py"
run_test "$CDIR/spmd/test_xla_auto_sharding.py"
run_test "$CDIR/spmd/test_spmd_parameter_wrapping.py"
run_test "$CDIR/spmd/test_mp_input_sharding.py"
run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
run_test "$CDIR/test_input_output_aliases.py"
run_test "$CDIR/test_torch_distributed_xla_backend.py"
151 changes: 151 additions & 0 deletions test/spmd/test_mp_input_sharding.py
@@ -0,0 +1,151 @@
import sys
import numpy as np
import unittest

import torch
import torch_xla
from torch_xla import runtime as xr
import torch_xla.core.xla_model as xm
from torch_xla.distributed.spmd import Mesh
import torch_xla.distributed.spmd as xs
import torch_xla.distributed.parallel_loader as pl

xr.use_spmd()


class MpInputShardingTest(unittest.TestCase):

  class fake_dataloader:

    def __init__(self, batch, size=1):
      self.batch = batch
      self.batch_size = size
      self.counter = 0

    def __iter__(self):
      return self

    def __next__(self):
      if self.counter < self.batch_size:
        self.counter += 1
        return self.batch
      raise StopIteration

  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                       "Multiple devices required for tupled partition spec")
  def test_multiple_inputs(self):
    device = xm.xla_device()
    batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))}
    train_loader = self.fake_dataloader(batch)
    num_devices = xr.global_runtime_device_count()
    mesh = xs.get_1d_mesh('x')

    train_loader = pl.MpDeviceLoader(
        train_loader,
        device,
        input_sharding={
            'x': xs.ShardingSpec(mesh, ('x', None)),
            'y': xs.ShardingSpec(mesh, ('x', None, None))
        })
    train_loader = iter(train_loader)
    data = next(train_loader)
    annotation_x = '{devices=[%d,1]%s}' % (num_devices, ','.join(
        [str(i) for i in range(num_devices)]))
    annotation_y = '{devices=[%d,1,1]%s}' % (num_devices, ','.join(
        [str(i) for i in range(num_devices)]))
    self.assertEqual(annotation_x,
                     torch_xla._XLAC._get_xla_sharding_spec(data['x']))
    self.assertEqual(annotation_y,
                     torch_xla._XLAC._get_xla_sharding_spec(data['y']))

  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                       "Multiple devices required for tupled partition spec")
  def test_single_tensor(self):
    device = xm.xla_device()
    batch = torch.randn((16, 128))
    train_loader = self.fake_dataloader(batch)
    num_devices = xr.global_runtime_device_count()
    mesh = xs.get_1d_mesh('x')

    train_loader = pl.MpDeviceLoader(
        train_loader, device, input_sharding=xs.ShardingSpec(mesh, ('x', None)))
    train_loader = iter(train_loader)
    data = next(train_loader)
    annotation = '{devices=[%d,1]%s}' % (num_devices, ','.join(
        [str(i) for i in range(num_devices)]))
    self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(data))

  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                       "Multiple devices required for tupled partition spec")
  def test_error_single_tensor_with_input_sharding_dict(self):
    device = xm.xla_device()
    batch = torch.randn((16, 128))
    train_loader = self.fake_dataloader(batch)
    num_devices = xr.global_runtime_device_count()
    mesh = xs.get_1d_mesh('x')

    train_loader = pl.MpDeviceLoader(
        train_loader,
        device,
        input_sharding={'x': xs.ShardingSpec(mesh, ('x', None))})
    train_loader = iter(train_loader)
    with self.assertRaises(ValueError):
      data = next(train_loader)

  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                       "Multiple devices required for tupled partition spec")
  def test_input_sharding_none(self):
    device = xm.xla_device()
    batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))}
    train_loader = self.fake_dataloader(batch)
    num_devices = xr.global_runtime_device_count()

    train_loader = pl.MpDeviceLoader(train_loader, device, input_sharding=None)
    train_loader = iter(train_loader)
    data = next(train_loader)
    annotation = '{replicated}'
    self.assertEqual(annotation,
                     torch_xla._XLAC._get_xla_sharding_spec(data['x']))
    self.assertEqual(annotation,
                     torch_xla._XLAC._get_xla_sharding_spec(data['y']))

  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                       "Multiple devices required for tupled partition spec")
  def test_error_missing_keys(self):
    device = xm.xla_device()
    batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))}
    train_loader = self.fake_dataloader(batch)
    mesh = xs.get_1d_mesh('x')
    train_loader = pl.MpDeviceLoader(
        train_loader,
        device,
        input_sharding={'x': xs.ShardingSpec(mesh, ('x', None))})
    train_loader = iter(train_loader)
    with self.assertRaises(KeyError):
      data = next(train_loader)

  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                       "Multiple devices required for tupled partition spec")
  def test_input_sharding_not_dict(self):
    device = xm.xla_device()
    num_devices = xr.global_runtime_device_count()
    batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128))}
    train_loader = self.fake_dataloader(batch)
    mesh = xs.get_1d_mesh('x')
    train_loader = pl.MpDeviceLoader(
        train_loader, device, input_sharding=xs.ShardingSpec(mesh, ('x', None)))
    train_loader = iter(train_loader)
    data = next(train_loader)
    annotation_x = '{devices=[%d,1]%s}' % (num_devices, ','.join(
        [str(i) for i in range(num_devices)]))
    annotation_y = '{devices=[%d,1]%s}' % (num_devices, ','.join(
        [str(i) for i in range(num_devices)]))
    self.assertEqual(annotation_x,
                     torch_xla._XLAC._get_xla_sharding_spec(data['x']))
    self.assertEqual(annotation_y,
                     torch_xla._XLAC._get_xla_sharding_spec(data['y']))


if __name__ == '__main__':
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
@@ -5,6 +5,7 @@ set -xue
python3 test/test_operations.py -v
python3 test/pjrt/test_runtime_tpu.py
python3 test/pjrt/test_collective_ops_tpu.py
python3 test/spmd/test_mp_input_sharding.py
python3 test/spmd/test_xla_sharding.py
python3 test/spmd/test_xla_virtual_device.py
python3 test/spmd/test_xla_distributed_checkpoint.py
79 changes: 64 additions & 15 deletions torch_xla/distributed/parallel_loader.py
@@ -1,4 +1,5 @@
import itertools
import queue
import threading
import torch
import torch_xla
@@ -12,7 +13,7 @@ class PerDeviceQueue(object):

  def __init__(self, device, loader_prefetch_size, device_prefetch_size):
    self.device = device
    self.loader_queue = kq.Queue(maxsize=loader_prefetch_size)
    self.cpu_loader_queue = kq.Queue(maxsize=loader_prefetch_size)
    self.queue = kq.Queue(maxsize=device_prefetch_size)
    self.close_queue_count = itertools.count()

@@ -47,6 +48,8 @@ def next(self):

    item = self._loader.next_item(self._device)
    if item is None:
      if not self._loader._exception_queue.empty():
        raise self._loader._exception_queue.get()
      xm.mark_step()
      raise StopIteration
    return item
@@ -56,7 +59,7 @@ class ParallelLoader(object):
  """Wraps an existing PyTorch DataLoader with background data upload.
  Args:
    loader (:class:`torch.utils.data.DataLoader`): The PyTorch DataLoader to be
    cpu_loader (:class:`torch.utils.data.DataLoader`): The PyTorch DataLoader to be
      wrapped.
    devices (`torch.device`...): The list of devices where the data has to be
      sent. The i-th sample returned by the `loader` will be sent to `devices[i
@@ -74,26 +77,26 @@ class ParallelLoader(object):
    host_to_device_transfer_threads (int, optional): The number of threads that
      work in parallel to transfer data from loader queue to device queue.
      Default: 1
    input_sharding (ShardingSpec, optional): Sharding spec to apply to
      compatible input tensors after loading.
      Default: None
    input_sharding (ShardingSpec, Dict(str, ShardingSpec), optional): Sharding
      spec to apply to compatible input tensors after loading.
  """

  def __init__(self,
               loader,
               cpu_loader,
               devices,
               batchdim=0,
               batches_per_execution=1,
               loader_prefetch_size=16,
               device_prefetch_size=8,
               host_to_device_transfer_threads=1,
               input_sharding=None):
    self._loader = loader
    self._cpu_loader = cpu_loader
    self._devices = [torch.device(x) for x in devices]
    self._batchdim = batchdim
    self._batches_per_execution = batches_per_execution
    self._done = False
    self._queues = dict()
    self._exception_queue = queue.Queue()
    self._input_sharding = input_sharding
    self._threads = []
    for device in self._devices:
@@ -140,7 +143,7 @@ def close(self):
    self._done = True
    for dqueue in self._queues.values():
      dqueue.queue.close()
      dqueue.loader_queue.close()
      dqueue.cpu_loader_queue.close()

    for thread in self._threads:
      thread.join()
@@ -151,7 +154,7 @@ def batches_per_execution(self):

  def _loader_worker(self):
    queues = list(self._queues.values())
    data_iter = enumerate(self._loader)
    data_iter = enumerate(self._cpu_loader)
    batch = []

    try:
@@ -163,21 +166,62 @@ def _loader_worker(self):
        batch.append(data)
        if len(batch) == len(self._devices):
          for queue_no, device_batch in enumerate(batch):
            queues[queue_no].loader_queue.put(device_batch)
            queues[queue_no].cpu_loader_queue.put(device_batch)
          batch = []
    finally:
      for dqueue in queues:
        dqueue.loader_queue.close_write()
        dqueue.cpu_loader_queue.close_write()

  def _get_batch(self, dqueue):
    batch = []
    while dqueue.queue.max_size() > len(batch):
      item = dqueue.loader_queue.get()
    while len(batch) < dqueue.queue.max_size():
      item = dqueue.cpu_loader_queue.get()
      if item is None:
        break
      batch.append(item)
    return batch

  def send_cpu_data_to_device(self, batches, device):
    """Move batch to device.
    Args:
      batches -> List(torch.Tensor), List(Dict(str: torch.Tensor)): Input batches
        present in the CPU memory
      device: TPU device where the batch should be moved
    Returns:
      result -> List(torch.Tensor), Dict(str: torch.Tensor): Returns a dict if the
        input batch is a dict. Otherwise, returns a list of torch.Tensor.
    """
    result = None
    if isinstance(self._input_sharding, dict):
      if not isinstance(batches[0], dict):
        raise ValueError(
            "input batch should be a dict when input sharding is a dict.")
      result = []
      for batch in batches:
        xla_batch = {}
        missing_keys = []
        for key, tensor in batch.items():
          assert type(tensor) == torch.Tensor
          sharding_spec = None
          if self._input_sharding:
            if key not in self._input_sharding:
              missing_keys.append(key)
              continue
            sharding_spec = self._input_sharding[key]

          # xla_tensor is a list of tensors.
          xla_tensor = xm.send_cpu_data_to_device(tensor, device, sharding_spec)
          xla_batch[key] = xla_tensor[0]
        if len(missing_keys) != 0:
          # Raise here; _worker catches the error and forwards it through the
          # exception queue, since an exception raised only in the data-loading
          # thread would not surface in the main thread.
          raise KeyError(
              f"Keys: {missing_keys} are missing from input_sharding.")
        result.append(xla_batch)
    else:
      result = xm.send_cpu_data_to_device(batches, device, self._input_sharding)
    return result

  def _worker(self, dqueue, host_to_device_transfer_threads):
    device = torch.device(dqueue.device)

@@ -187,8 +231,13 @@ def _worker(self, dqueue, host_to_device_transfer_threads):
        if not batch:
          break
        with torch.no_grad():
          batch = xm.send_cpu_data_to_device(batch, device,
                                             self._input_sharding)
          try:
            batch = self.send_cpu_data_to_device(batch, device)
          except Exception as e:
            # _worker is run in a daemon thread; raising the error there will
            # not surface it. Put the error in the exception queue instead.
            self._exception_queue.put(e)
            break
        for data in batch:
          dqueue.queue.put(data)
    finally:
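
Taken together, these changes let `MpDeviceLoader` accept either a single `ShardingSpec` or a dict keyed by batch entry, and surface loader-thread failures from `next()` in the main thread. A condensed sketch of the two new error paths (the inline list batches are hypothetical stand-ins for a real `DataLoader`; requires SPMD mode and more than one device, as in the tests above):

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.distributed.parallel_loader as pl
from torch_xla import runtime as xr

xr.use_spmd()
device = xm.xla_device()
mesh = xs.get_1d_mesh('x')
dict_sharding = {'x': xs.ShardingSpec(mesh, ('x', None))}

# Dict input_sharding but the batch is a bare tensor: the worker thread raises
# ValueError, which is queued and re-raised by PerDeviceLoader.next().
loader = pl.MpDeviceLoader([torch.randn(16, 128)], device,
                           input_sharding=dict_sharding)
# next(iter(loader))  # -> ValueError

# Dict batch with a key ('y') absent from input_sharding: KeyError on next().
loader = pl.MpDeviceLoader([{'x': torch.randn(16, 128),
                             'y': torch.randn(16, 128)}], device,
                           input_sharding=dict_sharding)
# next(iter(loader))  # -> KeyError
```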
