Commit d909935
Enable 2D sharding (#17)
Summary: This pull request adds 2D SPMD sharding support; it shards both weights and activations. The sharding strategy, assuming a 2D mesh (data, model) with data x model == num_devices:

1. input (data, None, model)
2. embedding (model, data)
3. attn QKV (data, model)
4. attn O (model, data)
5. mlp gate, up (model, data)
6. mlp down (data, model)
7. activation (data, None, model)

Currently you can specify the model dimension with a new option, --spmd_2d_sharding; the data dimension is then auto-calculated (see the mesh sketch below). TODO: maybe add another option to control whether the activations/inputs are sharded at all, or sharded differently.
1 parent 674ab35 commit d909935
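A minimal sketch, not part of the commit, of how the two mesh orientations above can be derived from the new flag, using the same torch_xla SPMD APIs the diff relies on; the value 4 for the model dimension is hypothetical:

import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr

spmd_2d_sharding = 4                              # hypothetical --spmd_2d_sharding value (model dim)
num_devices = xr.global_runtime_device_count()
data = num_devices // spmd_2d_sharding            # data dim is auto-calculated
assert data * spmd_2d_sharding == num_devices

# Weights listed as (data, model) above are sharded over the first mesh,
# weights listed as (model, data) over the second.
data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, spmd_2d_sharding))
model_data_mesh = xs.HybridMesh(ici_mesh_shape=(spmd_2d_sharding, data))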

File tree

3 files changed: +88 -3 lines changed
examples/pytorch/language-modeling/run_clm.py

Lines changed: 47 additions & 0 deletions

@@ -189,6 +189,14 @@ class ModelArguments:
             )
         },
     )
+    spmd_2d_sharding: int = field(
+        default=0,
+        metadata={
+            "help": (
+                "Will apply XLA SPMD to 2D sharding, i.e., weights + activations, and spmd_2d_sharding specifies the model dimension"
+            )
+        },
+    )
 
     def __post_init__(self):
         if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
@@ -297,6 +305,7 @@ def main():
     training_args.spmd_batch_sharding = model_args.spmd_batch_sharding or model_args.spmd_fsdp_sharding
     training_args.spmd_fsdp_sharding = model_args.spmd_fsdp_sharding
     training_args.spmd_tensor_sharding = model_args.spmd_tensor_sharding
+    training_args.spmd_2d_sharding = model_args.spmd_2d_sharding
 
     # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
     # information sent is the one passed as arguments along with your Python/PyTorch versions.
@@ -469,6 +478,8 @@ def main():
                 "You can do it from another script, save it, and load it from here, using --tokenizer_name."
             )
 
+    # Pass the 2d sharding config to the actual model.
+    config.spmd_2d_sharding = model_args.spmd_2d_sharding
     if model_args.model_name_or_path:
         torch_dtype = (
             model_args.torch_dtype
@@ -539,6 +550,42 @@ def main():
             else:
                 assert len(param.shape) == 2
                 xs.mark_sharding(param, mesh, range(len(param.shape)))
+    elif model_args.spmd_2d_sharding > 0:
+        print('Applying 2D sharding to all parameters')
+        for name, param in model.named_parameters():
+            # Apply 2D sharding:
+            # embedding (model, data)
+            # attn QKV (data, model)
+            # attn O (model, data)
+            # mlp gate, up (model, data)
+            # mlp down (data, model)
+            print('> Sharding tensor', name, param.shape)
+            mod = model_args.spmd_2d_sharding
+            data = num_devices // mod
+            assert mod * data == num_devices
+            data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, mod))
+            model_data_mesh = xs.HybridMesh(ici_mesh_shape=(mod, data))
+
+            # We don't care about layernorm's weights, and
+            # LLaMA doesn't use biases.
+            if len(param.shape) == 1:
+                continue
+
+            if 'embed_tokens' in name:
+                xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))
+            elif 'q_proj' in name or 'k_proj' in name or 'v_proj' in name:
+                xs.mark_sharding(param, data_model_mesh, range(len(param.shape)))
+            elif 'o_proj' in name:
+                xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))
+            elif 'gate_proj' in name or 'up_proj' in name:
+                xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))
+            elif 'down_proj' in name:
+                xs.mark_sharding(param, data_model_mesh, range(len(param.shape)))
+            elif 'lm_head' in name:  # Not sure what this is but has the same shape as embed_tokens
+                xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))
+
+            import torch_xla
+            print(torch_xla._XLAC._get_xla_sharding_spec(param))
 
     # Preprocessing the datasets.
     # First we tokenize all the texts.
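A hedged sketch, not the committed code, of the substring-to-mesh routing the loop above performs; the helper name pick_mesh_for is illustrative only, and the commented usage assumes the model and meshes from the surrounding script:

def pick_mesh_for(name, param, data_model_mesh, model_data_mesh):
    # 1D parameters (layernorm weights) are skipped; LLaMA has no biases.
    if len(param.shape) != 2:
        return None
    if any(key in name for key in ('q_proj', 'k_proj', 'v_proj', 'down_proj')):
        return data_model_mesh   # (data, model) layout
    if any(key in name for key in ('embed_tokens', 'o_proj', 'gate_proj', 'up_proj', 'lm_head')):
        return model_data_mesh   # (model, data) layout
    return None

# for name, param in model.named_parameters():
#     mesh = pick_mesh_for(name, param, data_model_mesh, model_data_mesh)
#     if mesh is not None:
#         xs.mark_sharding(param, mesh, range(len(param.shape)))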

src/transformers/models/llama/modeling_llama.py

Lines changed: 37 additions & 0 deletions

@@ -392,6 +392,8 @@ class LlamaAttention(nn.Module):
     def __init__(self, config: LlamaConfig):
         super().__init__()
         self.config = config
+        # For PyTorch/XLA's SPMD 2D sharding
+        self.spmd_2d_sharding = config.spmd_2d_sharding
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
@@ -540,6 +542,22 @@ def forward(
         if not output_attentions:
             attn_weights = None
 
+        # Apply 2D sharding:
+        # activation (data, None, model)
+        import torch_xla.core.xla_model as xm
+        import torch_xla.experimental.xla_sharding as xs
+        import torch_xla.runtime as xr
+        import torch_xla
+        num_devices = xr.global_runtime_device_count()
+        device_ids = torch.arange(num_devices)
+        print('> Sharding activations', attn_output.shape)
+        model = self.spmd_2d_sharding
+        data = num_devices // model
+        assert model * data == num_devices
+        data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
+        xs.mark_sharding(attn_output, data_model_mesh, (0, 1, 2))
+        print(torch_xla._XLAC._get_xla_sharding_spec(attn_output))
+
         return attn_output, attn_weights, past_key_value
 
 
@@ -935,6 +953,9 @@ class LlamaModel(LlamaPreTrainedModel):
 
     def __init__(self, config: LlamaConfig):
         super().__init__(config)
+        # For PyTorch/XLA's SPMD 2D sharding
+        self.spmd_2d_sharding = config.spmd_2d_sharding
+
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
@@ -1015,7 +1036,23 @@ def forward(
         )
 
         # embed positions
+        # Is this the input to the model?
         hidden_states = inputs_embeds
+        # Apply 2D sharding:
+        # input (data, None, model)
+        import torch_xla.core.xla_model as xm
+        import torch_xla.experimental.xla_sharding as xs
+        import torch_xla.runtime as xr
+        import torch_xla
+        num_devices = xr.global_runtime_device_count()
+        device_ids = torch.arange(num_devices)
+        print('> Sharding hidden_states', hidden_states.shape)
+        model = self.spmd_2d_sharding
+        data = num_devices // model
+        assert model * data == num_devices
+        data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
+        xs.mark_sharding(hidden_states, data_model_mesh, (0, 1, 2))
+        print(torch_xla._XLAC._get_xla_sharding_spec(hidden_states))
 
         if self.gradient_checkpointing and self.training:
             if use_cache:
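A minimal standalone sketch, not the committed code, of the activation annotation the hunks above add to LlamaAttention.forward and LlamaModel.forward: the (batch, seq, hidden) tensor gets the (data, None, model) layout, expressed here, as in the commit, with a 3D mesh whose middle axis has size 1. The tensor shape and the model dimension value are hypothetical:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr

num_devices = xr.global_runtime_device_count()
model_dim = 4                                     # hypothetical --spmd_2d_sharding value
data_dim = num_devices // model_dim
assert data_dim * model_dim == num_devices

# (batch, seq, hidden) activation on the XLA device.
activation = torch.randn(8, 2048, 4096, device=xm.xla_device())

# Mesh (data, 1, model): batch sharded over data, seq replicated, hidden sharded over model.
mesh = xs.HybridMesh(ici_mesh_shape=(data_dim, 1, model_dim))
xs.mark_sharding(activation, mesh, (0, 1, 2))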

src/transformers/trainer.py

Lines changed: 4 additions & 3 deletions

@@ -1427,10 +1427,11 @@ def _xla_sharded_dataloader(self, dataloader):
         if self.args.spmd_batch_sharding:
             mesh = xs.Mesh(device_ids, (num_devices, 1))
             sharding_spec = xs.ShardingSpec(mesh, (0, 1))
-        elif self.args.spmd_tensor_sharding > 0:
-            tensor = self.args.spmd_tensor_sharding
+        elif self.args.spmd_tensor_sharding > 0 or self.args.spmd_2d_sharding > 0:
+            assert self.args.spmd_tensor_sharding == 0 or self.args.spmd_2d_sharding == 0
+            tensor = self.args.spmd_tensor_sharding + self.args.spmd_2d_sharding
             fsdp = num_devices // tensor
-            mesh = xs.Mesh(device_ids, (fsdp, tensor))
+            mesh = xs.HybridMesh(ici_mesh_shape=(fsdp, tensor))
             partition_spec = (0, None)
             sharding_spec = xs.ShardingSpec(mesh, partition_spec)
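A minimal sketch, not the committed code, of the input sharding spec the hunk above builds: the global batch dimension of each (batch, seq) input is split along the fsdp/data mesh axis and replicated along the tensor/model axis via the (0, None) partition spec. The tensor_dim value is hypothetical:

import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr

num_devices = xr.global_runtime_device_count()
tensor_dim = 4                                    # --spmd_tensor_sharding or --spmd_2d_sharding
fsdp_dim = num_devices // tensor_dim

mesh = xs.HybridMesh(ici_mesh_shape=(fsdp_dim, tensor_dim))
sharding_spec = xs.ShardingSpec(mesh, (0, None))  # shard dim 0 (batch), replicate the rest
# The trainer attaches this spec to its XLA dataloader so each global batch is
# split into per-device slices of size batch / fsdp_dim along dim 0.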
