diff --git a/examples/pytorch/language-modeling/run_clm.py b/examples/pytorch/language-modeling/run_clm.py
index 77b022ab613b18..1180e9cae2d172 100755
--- a/examples/pytorch/language-modeling/run_clm.py
+++ b/examples/pytorch/language-modeling/run_clm.py
@@ -164,6 +164,14 @@ class ModelArguments:
             )
         },
     )
+    spmd_tensor_sharding: int = field(
+        default=0,
+        metadata={
+            "help": (
+                "Will apply XLA SPMD to shard the weights along two dimensions (num_devices / spmd_tensor_sharding, spmd_tensor_sharding)"
+            )
+        },
+    )
 
     def __post_init__(self):
         if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
@@ -265,6 +273,7 @@ def main():
     training_args.spmd_batch_sharding = model_args.spmd_batch_sharding or model_args.spmd_fsdp_sharding
     training_args.spmd_fsdp_sharding = model_args.spmd_fsdp_sharding
+    training_args.spmd_tensor_sharding = model_args.spmd_tensor_sharding
 
     # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
     # information sent is the one passed as arguments along with your Python/PyTorch versions.
@@ -490,7 +499,20 @@ def main():
                 shape[max_dim] = num_devices
                 mesh = xs.HybridMesh(ici_mesh_shape=tuple(shape))
                 xs.mark_sharding(param, mesh, range(len(param.shape)))
-
+        elif model_args.spmd_tensor_sharding > 0:
+            print('Applying 2-dimensional sharding to all parameters')
+            for name, param in model.named_parameters():
+                # Shard all parameters along two axes, except 1D tensors
+                print('> Sharding tensor', name, param.shape)
+                tensor = model_args.spmd_tensor_sharding
+                fsdp = num_devices // tensor
+                assert fsdp * tensor == num_devices
+                mesh = xs.Mesh(device_ids, (fsdp, tensor))
+                if len(param.shape) == 1:
+                    xs.mark_sharding(param, mesh, (1,))
+                else:
+                    assert len(param.shape) == 2
+                    xs.mark_sharding(param, mesh, range(len(param.shape)))
 
     # Preprocessing the datasets.
     # First we tokenize all the texts.
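
What the new spmd_tensor_sharding path does, in isolation: it builds a two-axis logical device mesh of shape (num_devices // spmd_tensor_sharding, spmd_tensor_sharding) and marks every 2D weight for sharding along both axes, while 1D tensors (biases, norm weights) are sharded along the tensor axis only. The sketch below reuses only the torch_xla SPMD calls that already appear in the diff; the parameter shapes and the value spmd_tensor_sharding = 4 are made up for illustration, and it assumes a TPU host with SPMD mode enabled.

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs

spmd_tensor_sharding = 4                    # illustrative value; must divide the device count
num_devices = xr.global_device_count()
fsdp = num_devices // spmd_tensor_sharding
assert fsdp * spmd_tensor_sharding == num_devices

device_ids = np.arange(num_devices)
mesh = xs.Mesh(device_ids, (fsdp, spmd_tensor_sharding))    # logical (fsdp, tensor) mesh

weight = torch.randn(4096, 11008, device=xm.xla_device())   # stand-in 2D parameter
bias = torch.randn(11008, device=xm.xla_device())           # stand-in 1D parameter

xs.mark_sharding(weight, mesh, (0, 1))   # dim 0 over the fsdp axis, dim 1 over the tensor axis
xs.mark_sharding(bias, mesh, (1,))       # 1D tensors are sharded over the tensor axis only

Because spmd_tensor_sharding is an ordinary ModelArguments field, HfArgumentParser exposes it as a --spmd_tensor_sharding flag on run_clm.py; the default of 0 leaves the existing batch/FSDP sharding paths untouched.
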
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 6b73d833beba37..5acba42a4cc97f 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1450,15 +1450,23 @@ def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
 
     def _xla_sharded_dataloader(self, dataloader):
         if is_torch_tpu_available():
+            import torch_xla.experimental.xla_sharding as xs
+            import torch_xla.runtime as xr
+            import torch_xla.distributed.parallel_loader as pl
+            num_devices = xr.global_device_count()
+            device_ids = np.arange(num_devices)
+            sharding_spec = None
             if self.args.spmd_batch_sharding:
-                import torch_xla.experimental.xla_sharding as xs
-                import torch_xla.runtime as xr
-                import torch_xla.distributed.parallel_loader as pl
-                num_devices = xr.global_device_count()
-                device_ids = np.arange(num_devices)
                 mesh = xs.Mesh(device_ids, (num_devices, 1))
                 sharding_spec = xs.ShardingSpec(mesh, (0, 1))
+            elif self.args.spmd_tensor_sharding > 0:
+                tensor = self.args.spmd_tensor_sharding
+                fsdp = num_devices // tensor
+                mesh = xs.Mesh(device_ids, (fsdp, tensor))
+                partition_spec = (0, None)
+                sharding_spec = xs.ShardingSpec(mesh, partition_spec)
+
             return pl.MpDeviceLoader(dataloader, self.args.device, input_sharding=sharding_spec, loader_prefetch_size=self.args.train_batch_size, device_prefetch_size=4)
         else:
             return dataloader
@@ -1818,6 +1826,7 @@ def _inner_training_loop(
                 self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
 
                 if step == profile_step and epoch == profile_epoch:
+                    import tempfile
                     trace = lambda: xp.trace('127.0.0.1:9012', profile_logdir or tempfile.mkdtemp(), profile_duration or 20000)
                     Thread(target=trace).start()
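
The trainer change mirrors that mesh on the input side: the torch_xla imports and device info are hoisted out of the batch-sharding branch, sharding_spec defaults to None, and for tensor sharding each batch is split along its batch dimension over the fsdp axis and replicated over the tensor axis, which is what partition_spec = (0, None) encodes. Below is a rough, self-contained sketch of the same wiring, assuming a TPU host with SPMD enabled; the toy dataset, batch size, and tensor = 4 are placeholders rather than values from the patch.

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs
import torch_xla.distributed.parallel_loader as pl

num_devices = xr.global_device_count()
device_ids = np.arange(num_devices)

tensor = 4                                   # placeholder for args.spmd_tensor_sharding
fsdp = num_devices // tensor
mesh = xs.Mesh(device_ids, (fsdp, tensor))

# Shard batches along dim 0 over the fsdp axis, replicate over the tensor axis.
input_sharding = xs.ShardingSpec(mesh, (0, None))

dataset = TensorDataset(torch.randint(0, 32000, (512, 1024)))  # toy token ids
# Batch size should divide evenly across the fsdp axis so dim 0 shards cleanly.
loader = pl.MpDeviceLoader(DataLoader(dataset, batch_size=16),
                           xm.xla_device(),
                           input_sharding=input_sharding)

for (input_ids,) in loader:
    # input_ids is already on the XLA device with the sharding annotation applied
    break
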