Introduce tensor sharding #14

Merged
24 changes: 23 additions & 1 deletion examples/pytorch/language-modeling/run_clm.py
@@ -164,6 +164,14 @@ class ModelArguments:
            )
        },
    )
    spmd_tensor_sharding: int = field(
        default=0,
        metadata={
            "help": (
                "Will apply XLA SPMD to shard the weights along two dimensions (num_devices / spmd_tensor_sharding, spmd_tensor_sharding)"
            )
        },
    )

    def __post_init__(self):
        if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
@@ -265,6 +273,7 @@ def main():

    training_args.spmd_batch_sharding = model_args.spmd_batch_sharding or model_args.spmd_fsdp_sharding
    training_args.spmd_fsdp_sharding = model_args.spmd_fsdp_sharding
    training_args.spmd_tensor_sharding = model_args.spmd_tensor_sharding

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
@@ -490,7 +499,20 @@ def main():
            shape[max_dim] = num_devices
            mesh = xs.HybridMesh(ici_mesh_shape=tuple(shape))
            xs.mark_sharding(param, mesh, range(len(param.shape)))

    elif model_args.spmd_tensor_sharding > 0:
        print('Applying 2 dimensions sharding to all parameters')
        for name, param in model.named_parameters():
            # Shard all parameters along two axes, except 1D tensors
            print('> Sharding tensor', name, param.shape)
            tensor = model_args.spmd_tensor_sharding
            fsdp = num_devices // tensor
            assert fsdp * tensor == num_devices
            mesh = xs.Mesh(device_ids, (fsdp, tensor))
            if len(param.shape) == 1:
                xs.mark_sharding(param, mesh, (1,))
            else:
                assert len(param.shape) == 2
                xs.mark_sharding(param, mesh, range(len(param.shape)))

    # Preprocessing the datasets.
    # First we tokenize all the texts.
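For readers who want to try the new flag outside the script, here is a minimal standalone sketch of the parameter sharding the hunk above applies. It assumes a PyTorch/XLA SPMD runtime with `torch_xla.experimental.xla_sharding` available and SPMD mode enabled; the toy `nn.Linear` module and the `spmd_tensor_sharding` value are hypothetical stand-ins for the real language model and the `--spmd_tensor_sharding` flag.

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs

# Hypothetical stand-ins for the real model and the --spmd_tensor_sharding value.
model = torch.nn.Linear(1024, 4096).to(xm.xla_device())
spmd_tensor_sharding = 2

num_devices = xr.global_device_count()
device_ids = np.arange(num_devices)

tensor = spmd_tensor_sharding
fsdp = num_devices // tensor
assert fsdp * tensor == num_devices  # the 2D mesh must cover every device exactly

# Logical 2D device mesh: axis 0 is the "fsdp" axis, axis 1 the "tensor" axis.
mesh = xs.Mesh(device_ids, (fsdp, tensor))

for name, param in model.named_parameters():
    if len(param.shape) == 1:
        # 1D tensors (biases, norm scales) are sharded along the tensor axis only.
        xs.mark_sharding(param, mesh, (1,))
    else:
        # 2D weights are sharded along both mesh axes.
        xs.mark_sharding(param, mesh, (0, 1))
```

As in the diff, the product of the two mesh axes has to equal the device count, for example fsdp=2 and tensor=2 on the four chips of a v4-8.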
19 changes: 14 additions & 5 deletions src/transformers/trainer.py
@@ -1450,15 +1450,23 @@ def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):

    def _xla_sharded_dataloader(self, dataloader):
        if is_torch_tpu_available():
            import torch_xla.experimental.xla_sharding as xs
            import torch_xla.runtime as xr
            import torch_xla.distributed.parallel_loader as pl
            num_devices = xr.global_device_count()
            device_ids = np.arange(num_devices)

            sharding_spec = None
            if self.args.spmd_batch_sharding:
                mesh = xs.Mesh(device_ids, (num_devices, 1))
                sharding_spec = xs.ShardingSpec(mesh, (0, 1))
            elif self.args.spmd_tensor_sharding > 0:
Reviewer:

Specifying batch or fsdp sharding will silently override the tensor parallelism; can we assert that these flags are exclusive, since tensor_sharding implies FSDP/batch sharding?

Collaborator Author:

Actually, it's intended: you can do tensor_sharding on the weights and batch sharding on the input.

Reviewer:

Oh I see - so we can run 2D FSDP by specifying --spmd_batch_sharding and e.g. --spmd_tensor_sharding 4. I think specifying --spmd_fsdp_sharding with --spmd_tensor_sharding 4 will always ignore the tensor_sharding though, right?

Collaborator Author:

What I mean is that you can specify:
--spmd_batch_sharding --spmd_tensor_sharding 4
but not
--spmd_fsdp_sharding --spmd_tensor_sharding 4

Do you think that's clear? If not, I can do a follow-up.

Reviewer:

Got it, thanks Jiewen! I think what we have now is fine. We can follow up later to make the sharding more standard, like MaxText has with its ici_*_parallelism and dcn_*_parallelism parameters.

Collaborator Author:

Yea, for sure. Does HybridMesh do anything for you in a single slice?

Reviewer:

Yeah, it will rearrange the tiling assignment to optimize for the ICI connections. I would say we should always use HybridMesh, even for a single slice.

Collaborator Author:

Trying to see if you have the MFU numbers to compare. I will make a change later.

Reviewer:

I'll do a quick test on v4-8 to get the MFU difference.

Collaborator Author:

No worries. It's not a priority.

                tensor = self.args.spmd_tensor_sharding
                fsdp = num_devices // tensor
                mesh = xs.Mesh(device_ids, (fsdp, tensor))
                partition_spec = (0, None)
                sharding_spec = xs.ShardingSpec(mesh, partition_spec)

            return pl.MpDeviceLoader(dataloader, self.args.device, input_sharding=sharding_spec, loader_prefetch_size=self.args.train_batch_size, device_prefetch_size=4)
        else:
            return dataloader
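Similarly, a self-contained sketch of the input sharding that the new spmd_tensor_sharding branch of `_xla_sharded_dataloader` sets up, assuming a TPU SPMD runtime; the toy dataset, batch size, and `tensor` value are hypothetical and stand in for the trainer's real dataloader and the `--spmd_tensor_sharding` flag.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs
import torch_xla.distributed.parallel_loader as pl

num_devices = xr.global_device_count()
device_ids = np.arange(num_devices)

tensor = 2                       # hypothetical --spmd_tensor_sharding value
fsdp = num_devices // tensor
mesh = xs.Mesh(device_ids, (fsdp, tensor))

# Shard the batch dimension across the fsdp axis and replicate along the
# tensor axis, mirroring the (0, None) partition spec above.
sharding_spec = xs.ShardingSpec(mesh, (0, None))

dataset = TensorDataset(torch.randn(64, 128))   # toy data in place of the LM inputs
loader = DataLoader(dataset, batch_size=16)

# Each host-side batch is sharded according to sharding_spec before being
# transferred to the XLA device.
device_loader = pl.MpDeviceLoader(loader, xm.xla_device(), input_sharding=sharding_spec)
```

So only the inputs' batch dimension is split here, while the weights are sharded across both mesh axes by the run_clm.py loop; this is the combination the review thread describes as 2D FSDP.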
@@ -1818,6 +1826,7 @@ def _inner_training_loop(
                self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

                if step == profile_step and epoch == profile_epoch:
                    import tempfile
                    trace = lambda: xp.trace('127.0.0.1:9012', profile_logdir or tempfile.mkdtemp(), profile_duration or 20000)
                    Thread(target=trace).start()

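Finally, picking up the HybridMesh discussion from the review thread, a hedged sketch of what swapping the plain Mesh for a HybridMesh of the same logical shape could look like. The `tensor` value is a hypothetical example, and the ICI-aware device ordering described in the comment is taken from the reviewer's remark rather than verified here.

```python
import numpy as np
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs

num_devices = xr.global_device_count()
tensor = 4                      # hypothetical --spmd_tensor_sharding value
fsdp = num_devices // tensor    # num_devices must be divisible by tensor

# Current approach: devices laid out in their default (flat) order.
flat_mesh = xs.Mesh(np.arange(num_devices), (fsdp, tensor))

# HybridMesh with the same logical shape: per the discussion above, it permutes
# the device assignment so that neighbouring mesh coordinates map to
# ICI-connected chips, which should help even on a single slice.
hybrid_mesh = xs.HybridMesh(ici_mesh_shape=(fsdp, tensor))

# Either mesh can be passed to xs.mark_sharding or xs.ShardingSpec unchanged.
```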