Skip to content

Commit

Permalink
Introduce tensor sharding (#14)
Browse files Browse the repository at this point in the history
Summary:
This pull request introduce a new way to do sharding which allow weights to be sharded in two dimensional mesh, i.e., (fsdp, tensor), and then the input to be sharded according to the fsdp dimension.

To enable it, pass --spmd_tensor_sharding 2, 2 is the tensor dimension, the fsdp dimension will be auto calculated according to num_devices // 2.

Test Plan:
Test it on a V4-8 with 2B LLaMA.
  • Loading branch information
alanwaketan authored Jul 25, 2023
1 parent e7ea6ea commit ba5c61d
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 6 deletions.
24 changes: 23 additions & 1 deletion examples/pytorch/language-modeling/run_clm.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,14 @@ class ModelArguments:
)
},
)
spmd_tensor_sharding: int = field(
default=0,
metadata={
"help": (
"Will apply XLA SPMD to shard the weights along two dimensions (num_devices / spmd_tensor_sharding, spmd_tensor_sharding)"
)
},
)

def __post_init__(self):
if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
Expand Down Expand Up @@ -265,6 +273,7 @@ def main():

training_args.spmd_batch_sharding = model_args.spmd_batch_sharding or model_args.spmd_fsdp_sharding
training_args.spmd_fsdp_sharding = model_args.spmd_fsdp_sharding
training_args.spmd_tensor_sharding = model_args.spmd_tensor_sharding

# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
Expand Down Expand Up @@ -490,7 +499,20 @@ def main():
shape[max_dim] = num_devices
mesh = xs.HybridMesh(ici_mesh_shape=tuple(shape))
xs.mark_sharding(param, mesh, range(len(param.shape)))

elif model_args.spmd_tensor_sharding > 0:
print('Applying 2 dimensions sharding to all parameters')
for name, param in model.named_parameters():
# Shard all parameters along two axis except 1D tensors
print('> Sharding tensor', name, param.shape)
tensor = model_args.spmd_tensor_sharding
fsdp = num_devices // tensor
assert fsdp * tensor == num_devices
mesh = xs.Mesh(device_ids, (fsdp, tensor))
if len(param.shape) == 1:
xs.mark_sharding(param, mesh, (1,))
else:
assert len(param.shape) == 2
xs.mark_sharding(param, mesh, range(len(param.shape)))

# Preprocessing the datasets.
# First we tokenize all the texts.
Expand Down
19 changes: 14 additions & 5 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1450,15 +1450,23 @@ def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):

def _xla_sharded_dataloader(self, dataloader):
if is_torch_tpu_available():
import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr
import torch_xla.distributed.parallel_loader as pl
num_devices = xr.global_device_count()
device_ids = np.arange(num_devices)

sharding_spec = None
if self.args.spmd_batch_sharding:
import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr
import torch_xla.distributed.parallel_loader as pl
num_devices = xr.global_device_count()
device_ids = np.arange(num_devices)
mesh = xs.Mesh(device_ids, (num_devices, 1))
sharding_spec = xs.ShardingSpec(mesh, (0, 1))
elif self.args.spmd_tensor_sharding > 0:
tensor = self.args.spmd_tensor_sharding
fsdp = num_devices // tensor
mesh = xs.Mesh(device_ids, (fsdp, tensor))
partition_spec = (0, None)
sharding_spec = xs.ShardingSpec(mesh, partition_spec)

return pl.MpDeviceLoader(dataloader, self.args.device, input_sharding=sharding_spec, loader_prefetch_size=self.args.train_batch_size, device_prefetch_size=4)
else:
return dataloader
Expand Down Expand Up @@ -1818,6 +1826,7 @@ def _inner_training_loop(
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

if step == profile_step and epoch == profile_epoch:
import tempfile
trace = lambda: xp.trace('127.0.0.1:9012', profile_logdir or tempfile.mkdtemp(), profile_duration or 20000)
Thread(target=trace).start()

Expand Down

0 comments on commit ba5c61d

Please sign in to comment.