diff --git a/examples/pytorch/language-modeling/run_clm.py b/examples/pytorch/language-modeling/run_clm.py
index 8a6b598789622d..3161f63293a997 100755
--- a/examples/pytorch/language-modeling/run_clm.py
+++ b/examples/pytorch/language-modeling/run_clm.py
@@ -482,8 +482,13 @@ def main():
     for name, param in model.named_parameters():
         # Shard all parameters along a single axis
         print('> Sharding tensor', name)
-        shape = (num_devices,) + (1,) * (len(param.shape) - 1)
-        mesh = xs.Mesh(device_ids, shape)
+
+        # Shard along the largest dimension
+        import numpy as np
+        max_dim = np.argmax(param.shape)
+        shape = [1] * len(param.shape)
+        shape[max_dim] = num_devices
+        mesh = xs.Mesh(device_ids, tuple(shape))
         xs.mark_sharding(param, mesh, range(len(param.shape)))