From ec6c2d6c05d8d86bc48422c918ab7f3ea3243c37 Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Tue, 25 Jul 2023 11:40:37 -0700 Subject: [PATCH] Introduce tensor sharding (#14) Summary: This pull request introduce a new way to do sharding which allow weights to be sharded in two dimensional mesh, i.e., (fsdp, tensor), and then the input to be sharded according to the fsdp dimension. To enable it, pass --spmd_tensor_sharding 2, 2 is the tensor dimension, the fsdp dimension will be auto calculated according to num_devices // 2. Test Plan: Test it on a V4-8 with 2B LLaMA. --- examples/pytorch/language-modeling/run_clm.py | 24 ++++++++++++++++++- src/transformers/trainer.py | 19 +++++++++++---- 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/examples/pytorch/language-modeling/run_clm.py b/examples/pytorch/language-modeling/run_clm.py index efbd5e2bcd6b96..bd3fdb927e3f96 100755 --- a/examples/pytorch/language-modeling/run_clm.py +++ b/examples/pytorch/language-modeling/run_clm.py @@ -181,6 +181,14 @@ class ModelArguments: ) }, ) + spmd_tensor_sharding: int = field( + default=0, + metadata={ + "help": ( + "Will apply XLA SPMD to shard the weights along two dimensions (num_devices / spmd_tensor_sharding, spmd_tensor_sharding)" + ) + }, + ) def __post_init__(self): if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None): @@ -288,6 +296,7 @@ def main(): training_args.spmd_batch_sharding = model_args.spmd_batch_sharding or model_args.spmd_fsdp_sharding training_args.spmd_fsdp_sharding = model_args.spmd_fsdp_sharding + training_args.spmd_tensor_sharding = model_args.spmd_tensor_sharding # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The # information sent is the one passed as arguments along with your Python/PyTorch versions. @@ -516,7 +525,20 @@ def main(): shape[max_dim] = num_devices mesh = xs.HybridMesh(ici_mesh_shape=tuple(shape)) xs.mark_sharding(param, mesh, range(len(param.shape))) - + elif model_args.spmd_tensor_sharding > 0: + print('Applying 2 dimensions sharding to all parameters') + for name, param in model.named_parameters(): + # Shard all parameters along two axis except 1D tensors + print('> Sharding tensor', name, param.shape) + tensor = model_args.spmd_tensor_sharding + fsdp = num_devices // tensor + assert fsdp * tensor == num_devices + mesh = xs.Mesh(device_ids, (fsdp, tensor)) + if len(param.shape) == 1: + xs.mark_sharding(param, mesh, (1,)) + else: + assert len(param.shape) == 2 + xs.mark_sharding(param, mesh, range(len(param.shape))) # Preprocessing the datasets. # First we tokenize all the texts. diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 8c8ef97cd44e70..99166813df0683 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1417,15 +1417,23 @@ def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}): def _xla_sharded_dataloader(self, dataloader): if is_torch_tpu_available(): + import torch_xla.experimental.xla_sharding as xs + import torch_xla.runtime as xr + import torch_xla.distributed.parallel_loader as pl + num_devices = xr.global_device_count() + device_ids = np.arange(num_devices) + sharding_spec = None if self.args.spmd_batch_sharding: - import torch_xla.experimental.xla_sharding as xs - import torch_xla.runtime as xr - import torch_xla.distributed.parallel_loader as pl - num_devices = xr.global_device_count() - device_ids = np.arange(num_devices) mesh = xs.Mesh(device_ids, (num_devices, 1)) sharding_spec = xs.ShardingSpec(mesh, (0, 1)) + elif self.args.spmd_tensor_sharding > 0: + tensor = self.args.spmd_tensor_sharding + fsdp = num_devices // tensor + mesh = xs.Mesh(device_ids, (fsdp, tensor)) + partition_spec = (0, None) + sharding_spec = xs.ShardingSpec(mesh, partition_spec) + return pl.MpDeviceLoader(dataloader, self.args.device, input_sharding=sharding_spec, loader_prefetch_size=self.args.train_batch_size, device_prefetch_size=4) else: return dataloader @@ -1833,6 +1841,7 @@ def _inner_training_loop( self.control = self.callback_handler.on_step_begin(args, self.state, self.control) if step == profile_step and epoch == profile_epoch: + import tempfile trace = lambda: xp.trace('127.0.0.1:9012', profile_logdir or tempfile.mkdtemp(), profile_duration or 20000) Thread(target=trace).start()