diff --git a/docs/spmd_advanced.md b/docs/spmd_advanced.md index 4cd07a558c9..369fdfe2570 100644 --- a/docs/spmd_advanced.md +++ b/docs/spmd_advanced.md @@ -10,10 +10,28 @@ PyTorch/XLA SPMD takes a single-device program, shards and executes it in parall train_loader = pl.MpDeviceLoader( train_loader, # wraps PyTorch DataLoader device, - # assume 4d input and we want to shard at the batch dimension. + # assume 4d input and we want to shard at the batch dimension. input_sharding=xs.ShardingSpec(input_mesh, ('data', None, None, None))) ``` +It is also possible to specify a different `input_sharding` for each element of the batch if they are different shapes: + +```python +# if batch = next(train_loader) looks like +# {'x': , 'y': } + +# MpDeviceLoader returns ParallelLoader.per_device_loader as iterator +train_loader = pl.MpDeviceLoader( + train_loader, # wraps PyTorch DataLoader + device, + # specify different sharding for each input of the batch. + input_sharding={ + 'x': xs.ShardingSpec(input_mesh, ('data', None, None, None)), + 'y': xs.ShardingSpec(input_mesh, ('data', None)) + } +) +``` + ### Virtual Device Optimization PyTorch/XLA normally transfers tensor data asynchronously from host to device once the tensor is defined. This is to overlap the data transfer with the graph tracing time. However, because GSPMD allows the user to modify the tensor sharding _after _the tensor has been defined, we need an optimization to prevent unnecessary transfer of tensor data back and forth between host and device. We introduce Virtual Device Optimization, a technique to place the tensor data on a virtual device SPMD:0 first, before uploading to the physical devices when all the sharding decisions are finalized. Every tensor data in SPMD mode is placed on a virtual device, SPMD:0. The virtual device is exposed to the user as an XLA device XLA:0 with the actual shards on physical devices, like TPU:0, TPU:1, etc. diff --git a/test/run_tests.sh b/test/run_tests.sh index 9a8c8fce9d5..0912d53ded5 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -245,6 +245,7 @@ function run_xla_op_tests3 { run_test "$CDIR/spmd/test_dtensor_integration2.py" run_test "$CDIR/spmd/test_xla_auto_sharding.py" run_test "$CDIR/spmd/test_spmd_parameter_wrapping.py" + run_test "$CDIR/spmd/test_mp_input_sharding.py" run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY run_test "$CDIR/test_input_output_aliases.py" run_test "$CDIR/test_torch_distributed_xla_backend.py" diff --git a/test/spmd/test_mp_input_sharding.py b/test/spmd/test_mp_input_sharding.py new file mode 100644 index 00000000000..6b78a3714e7 --- /dev/null +++ b/test/spmd/test_mp_input_sharding.py @@ -0,0 +1,151 @@ +import sys +import numpy as np +import unittest + +import torch +import torch_xla +from torch_xla import runtime as xr +import torch_xla.core.xla_model as xm +from torch_xla.distributed.spmd import Mesh +import torch_xla.distributed.spmd as xs +import torch_xla.distributed.parallel_loader as pl + +xr.use_spmd() + + +class MpInputShardingTest(unittest.TestCase): + + class fake_dataloader: + + def __init__(self, batch, size=1): + self.batch = batch + self.batch_size = size + self.counter = 0 + + def __iter__(self): + return self + + def __next__(self): + if self.counter < self.batch_size: + self.counter += 1 + return self.batch + raise StopIteration + + @unittest.skipUnless(xr.global_runtime_device_count() > 1, + "Multiple devices required for tupled partition spec") + def test_multiple_inputs(self): + device = xm.xla_device() + batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))} + train_loader = self.fake_dataloader(batch) + num_devices = xr.global_runtime_device_count() + mesh = xs.get_1d_mesh('x') + + train_loader = pl.MpDeviceLoader( + train_loader, + device, + input_sharding={ + 'x': xs.ShardingSpec(mesh, ('x', None)), + 'y': xs.ShardingSpec(mesh, ('x', None, None)) + }) + train_loader = iter(train_loader) + data = next(train_loader) + annotation_x = '{devices=[%d,1]%s}' % (num_devices, ','.join( + [str(i) for i in range(num_devices)])) + annotation_y = '{devices=[%d,1,1]%s}' % (num_devices, ','.join( + [str(i) for i in range(num_devices)])) + self.assertEqual(annotation_x, + torch_xla._XLAC._get_xla_sharding_spec(data['x'])) + self.assertEqual(annotation_y, + torch_xla._XLAC._get_xla_sharding_spec(data['y'])) + + @unittest.skipUnless(xr.global_runtime_device_count() > 1, + "Multiple devices required for tupled partition spec") + def test_single_tensor(self): + device = xm.xla_device() + batch = torch.randn((16, 128)) + train_loader = self.fake_dataloader(batch) + num_devices = xr.global_runtime_device_count() + mesh = xs.get_1d_mesh('x') + + train_loader = pl.MpDeviceLoader( + train_loader, device, input_sharding=xs.ShardingSpec(mesh, ('x', None))) + train_loader = iter(train_loader) + data = next(train_loader) + annotation = '{devices=[%d,1]%s}' % (num_devices, ','.join( + [str(i) for i in range(num_devices)])) + self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(data)) + + @unittest.skipUnless(xr.global_runtime_device_count() > 1, + "Multiple devices required for tupled partition spec") + def test_error_single_tensor_with_input_sharding_dict(self): + device = xm.xla_device() + batch = torch.randn((16, 128)) + train_loader = self.fake_dataloader(batch) + num_devices = xr.global_runtime_device_count() + mesh = xs.get_1d_mesh('x') + + train_loader = pl.MpDeviceLoader( + train_loader, + device, + input_sharding={'x': xs.ShardingSpec(mesh, ('x', None))}) + train_loader = iter(train_loader) + with self.assertRaises(ValueError): + data = next(train_loader) + + @unittest.skipUnless(xr.global_runtime_device_count() > 1, + "Multiple devices required for tupled partition spec") + def test_input_sharding_none(self): + device = xm.xla_device() + batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))} + train_loader = self.fake_dataloader(batch) + num_devices = xr.global_runtime_device_count() + + train_loader = pl.MpDeviceLoader(train_loader, device, input_sharding=None) + train_loader = iter(train_loader) + data = next(train_loader) + annotation = '{replicated}' + self.assertEqual(annotation, + torch_xla._XLAC._get_xla_sharding_spec(data['x'])) + self.assertEqual(annotation, + torch_xla._XLAC._get_xla_sharding_spec(data['y'])) + + @unittest.skipUnless(xr.global_runtime_device_count() > 1, + "Multiple devices required for tupled partition spec") + def test_error_missing_keys(self): + device = xm.xla_device() + batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))} + train_loader = self.fake_dataloader(batch) + mesh = xs.get_1d_mesh('x') + train_loader = pl.MpDeviceLoader( + train_loader, + device, + input_sharding={'x': xs.ShardingSpec(mesh, ('x', None))}) + train_loader = iter(train_loader) + with self.assertRaises(KeyError): + data = next(train_loader) + + @unittest.skipUnless(xr.global_runtime_device_count() > 1, + "Multiple devices required for tupled partition spec") + def test_input_sharding_not_dict(self): + device = xm.xla_device() + num_devices = xr.global_runtime_device_count() + batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128))} + train_loader = self.fake_dataloader(batch) + mesh = xs.get_1d_mesh('x') + train_loader = pl.MpDeviceLoader( + train_loader, device, input_sharding=xs.ShardingSpec(mesh, ('x', None))) + train_loader = iter(train_loader) + data = next(train_loader) + annotation_x = '{devices=[%d,1]%s}' % (num_devices, ','.join( + [str(i) for i in range(num_devices)])) + annotation_y = '{devices=[%d,1]%s}' % (num_devices, ','.join( + [str(i) for i in range(num_devices)])) + self.assertEqual(annotation_x, + torch_xla._XLAC._get_xla_sharding_spec(data['x'])) + self.assertEqual(annotation_y, + torch_xla._XLAC._get_xla_sharding_spec(data['y'])) + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index 52d1de5b150..89661e29a58 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -5,6 +5,7 @@ set -xue python3 test/test_operations.py -v python3 test/pjrt/test_runtime_tpu.py python3 test/pjrt/test_collective_ops_tpu.py +python3 test/spmd/test_mp_input_sharding.py python3 test/spmd/test_xla_sharding.py python3 test/spmd/test_xla_virtual_device.py python3 test/spmd/test_xla_distributed_checkpoint.py diff --git a/torch_xla/distributed/parallel_loader.py b/torch_xla/distributed/parallel_loader.py index a0304b4523a..a177c92b59d 100644 --- a/torch_xla/distributed/parallel_loader.py +++ b/torch_xla/distributed/parallel_loader.py @@ -1,4 +1,5 @@ import itertools +import queue import threading import torch import torch_xla @@ -12,7 +13,7 @@ class PerDeviceQueue(object): def __init__(self, device, loader_prefetch_size, device_prefetch_size): self.device = device - self.loader_queue = kq.Queue(maxsize=loader_prefetch_size) + self.cpu_loader_queue = kq.Queue(maxsize=loader_prefetch_size) self.queue = kq.Queue(maxsize=device_prefetch_size) self.close_queue_count = itertools.count() @@ -47,6 +48,8 @@ def next(self): item = self._loader.next_item(self._device) if item is None: + if not self._loader._exception_queue.empty(): + raise self._loader._exception_queue.get() xm.mark_step() raise StopIteration return item @@ -56,7 +59,7 @@ class ParallelLoader(object): """Wraps an existing PyTorch DataLoader with background data upload. Args: - loader (:class:`torch.utils.data.DataLoader`): The PyTorch DataLoader to be + cpu_loader (:class:`torch.utils.data.DataLoader`): The PyTorch DataLoader to be wrapped. devices (`torch.device`...): The list of devices where the data has to be sent. The i-th sample returned by the `loader` will be sent to `devices[i @@ -74,13 +77,12 @@ class ParallelLoader(object): host_to_device_transfer_threads (int, optional): The number of threads that work in parallel to transfer data from loader queue to device queue. Default: 1 - input_sharding (ShardingSpec, optional): Sharding spec to apply to - compatible input tensors after loading. - Default: None + input_sharding (ShardingSpec, Dict(str, ShardingSpec), optional): Sharding + spec to apply to compatible input tensors after loading. """ def __init__(self, - loader, + cpu_loader, devices, batchdim=0, batches_per_execution=1, @@ -88,12 +90,13 @@ def __init__(self, device_prefetch_size=8, host_to_device_transfer_threads=1, input_sharding=None): - self._loader = loader + self._cpu_loader = cpu_loader self._devices = [torch.device(x) for x in devices] self._batchdim = batchdim self._batches_per_execution = batches_per_execution self._done = False self._queues = dict() + self._exception_queue = queue.Queue() self._input_sharding = input_sharding self._threads = [] for device in self._devices: @@ -140,7 +143,7 @@ def close(self): self._done = True for dqueue in self._queues.values(): dqueue.queue.close() - dqueue.loader_queue.close() + dqueue.cpu_loader_queue.close() for thread in self._threads: thread.join() @@ -151,7 +154,7 @@ def batches_per_execution(self): def _loader_worker(self): queues = list(self._queues.values()) - data_iter = enumerate(self._loader) + data_iter = enumerate(self._cpu_loader) batch = [] try: @@ -163,21 +166,62 @@ def _loader_worker(self): batch.append(data) if len(batch) == len(self._devices): for queue_no, device_batch in enumerate(batch): - queues[queue_no].loader_queue.put(device_batch) + queues[queue_no].cpu_loader_queue.put(device_batch) batch = [] finally: for dqueue in queues: - dqueue.loader_queue.close_write() + dqueue.cpu_loader_queue.close_write() def _get_batch(self, dqueue): batch = [] - while dqueue.queue.max_size() > len(batch): - item = dqueue.loader_queue.get() + while len(batch) < dqueue.queue.max_size(): + item = dqueue.cpu_loader_queue.get() if item is None: break batch.append(item) return batch + def send_cpu_data_to_device(self, batches, device): + """Move batch to device. + Args: + batch -> List(torch.Tensor), List(Dict(str: torch.Tensor)): Input batch + present in the cpu memory + device: TPU device where the batch should be moved + + Returns: + result -> List(torch.Tensor), Dict(str: torch.Tensor): Returns a dict if the + input batch is a dict. Otherwise, returns a list of torch.Tensor. + """ + result = None + if isinstance(self._input_sharding, dict): + if not isinstance(batches[0], dict): + raise ValueError( + f"input batch should be a dict when input sharding is a dict.") + result = [] + for batch in batches: + xla_batch = {} + missing_keys = [] + for key, tensor in batch.items(): + assert type(tensor) == torch.Tensor + sharding_spec = None + if self._input_sharding: + if key not in self._input_sharding: + missing_keys.append(key) + continue + sharding_spec = self._input_sharding[key] + + # xla_tensor is a list of tensors. + xla_tensor = xm.send_cpu_data_to_device(tensor, device, sharding_spec) + xla_batch[key] = xla_tensor[0] + if len(missing_keys) != 0: + # Returning exception as raising in the dataloading thread doesn't surface the problem in the main thread. + raise KeyError( + f"Keys: {missing_keys} are missing from input_sharding.") + result.append(xla_batch) + else: + result = xm.send_cpu_data_to_device(batches, device, self._input_sharding) + return result + def _worker(self, dqueue, host_to_device_transfer_threads): device = torch.device(dqueue.device) @@ -187,8 +231,13 @@ def _worker(self, dqueue, host_to_device_transfer_threads): if not batch: break with torch.no_grad(): - batch = xm.send_cpu_data_to_device(batch, device, - self._input_sharding) + try: + batch = self.send_cpu_data_to_device(batch, device) + except Exception as e: + # _worker is being run in a daemon thread, raise the error + # will not work. Put the error in an error queue instead. + self._exception_queue.put(e) + break for data in batch: dqueue.queue.put(data) finally: