
Commit be45f50

alanwaketan authored and vanbasten23 committed
Guard 2D sharding for activations and inputs (#18)
Summary: This pull request fixes a bug in #17, which forgot to guard 2D sharding for activations and inputs.
Test Plan: N/A.
1 parent 6e42456 commit be45f50
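
For context, the fix simply wraps the sharding annotations introduced in #17 behind the `spmd_2d_sharding` flag. Below is a minimal standalone sketch of that guarded path, assuming `spmd_2d_sharding` holds the model-axis size and that 0 means disabled; the helper name `shard_2d` is hypothetical and only mirrors the calls that appear in the diff:

import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr


def shard_2d(tensor, spmd_2d_sharding):
    """Mark `tensor` for (data, None, model) 2D sharding; no-op when disabled."""
    if spmd_2d_sharding <= 0:
        return tensor  # guard added by this commit: leave the tensor unsharded
    num_devices = xr.global_runtime_device_count()
    model = spmd_2d_sharding             # model-parallel axis size
    data = num_devices // model          # data-parallel axis size
    assert model * data == num_devices
    # Mesh axis 1 has size 1, so tensor dim 1 (sequence) stays replicated.
    mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
    xs.mark_sharding(tensor, mesh, (0, 1, 2))
    return tensor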

1 file changed: +32 −30 lines changed


src/transformers/models/llama/modeling_llama.py

Lines changed: 32 additions & 30 deletions
@@ -401,21 +401,22 @@ def forward(
         if not output_attentions:
             attn_weights = None
 
-        # Apply 2D sharding:
-        # activation (data,, None, model)
-        import torch_xla.core.xla_model as xm
-        import torch_xla.experimental.xla_sharding as xs
-        import torch_xla.runtime as xr
-        import torch_xla
-        num_devices = xr.global_runtime_device_count()
-        device_ids = torch.arange(num_devices)
-        print('> Sharding activations', attn_output.shape)
-        model = self.spmd_2d_sharding
-        data = num_devices // model
-        assert model * data == num_devices
-        data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
-        xs.mark_sharding(attn_output, data_model_mesh, (0, 1, 2))
-        print(torch_xla._XLAC._get_xla_sharding_spec(attn_output))
+        if self.spmd_2d_sharding > 0:
+            # Apply 2D sharding:
+            # activation (data,, None, model)
+            import torch_xla.core.xla_model as xm
+            import torch_xla.experimental.xla_sharding as xs
+            import torch_xla.runtime as xr
+            import torch_xla
+            num_devices = xr.global_runtime_device_count()
+            device_ids = torch.arange(num_devices)
+            print('> Sharding activations', attn_output.shape)
+            model = self.spmd_2d_sharding
+            data = num_devices // model
+            assert model * data == num_devices
+            data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
+            xs.mark_sharding(attn_output, data_model_mesh, (0, 1, 2))
+            print(torch_xla._XLAC._get_xla_sharding_spec(attn_output))
 
         return attn_output, attn_weights, past_key_value
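
To make the mesh arithmetic in the hunk above concrete: with an assumed 8 global devices and `spmd_2d_sharding = 2` (both values illustrative, not taken from the diff), the guard builds a (data, 1, model) mesh of shape (4, 1, 2):

# Illustrative values only: 8 devices and spmd_2d_sharding = 2 are assumed.
num_devices = 8
model = 2                          # self.spmd_2d_sharding
data = num_devices // model        # 4
assert model * data == num_devices
ici_mesh_shape = (data, 1, model)  # (4, 1, 2)
# attn_output has shape (batch, seq_len, hidden); partition spec (0, 1, 2)
# shards batch across the 4 data-parallel devices, replicates seq_len
# (mesh axis of size 1), and shards hidden across the 2 model-parallel devices.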

@@ -1015,21 +1016,22 @@ def forward(
         # embed positions
         # Is this the input to the model?
         hidden_states = inputs_embeds
-        # Apply 2D sharding:
-        # input (data,, None, model)
-        import torch_xla.core.xla_model as xm
-        import torch_xla.experimental.xla_sharding as xs
-        import torch_xla.runtime as xr
-        import torch_xla
-        num_devices = xr.global_runtime_device_count()
-        device_ids = torch.arange(num_devices)
-        print('> Sharding hidden_states', hidden_states.shape)
-        model = self.spmd_2d_sharding
-        data = num_devices // model
-        assert model * data == num_devices
-        data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
-        xs.mark_sharding(hidden_states, data_model_mesh, (0, 1, 2))
-        print(torch_xla._XLAC._get_xla_sharding_spec(hidden_states))
+        if self.spmd_2d_sharding > 0:
+            # Apply 2D sharding:
+            # input (data,, None, model)
+            import torch_xla.core.xla_model as xm
+            import torch_xla.experimental.xla_sharding as xs
+            import torch_xla.runtime as xr
+            import torch_xla
+            num_devices = xr.global_runtime_device_count()
+            device_ids = torch.arange(num_devices)
+            print('> Sharding hidden_states', hidden_states.shape)
+            model = self.spmd_2d_sharding
+            data = num_devices // model
+            assert model * data == num_devices
+            data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
+            xs.mark_sharding(hidden_states, data_model_mesh, (0, 1, 2))
+            print(torch_xla._XLAC._get_xla_sharding_spec(hidden_states))
 
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
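
The second hunk applies the identical guard to the embedded inputs. A quick check sketch of the guard's two paths, assuming an SPMD-enabled torch_xla runtime; the tensor shape and flag value below are illustrative only:

# Check sketch: with the flag at 0 the sharding spec stays empty (unsharded);
# with a positive value it reflects the (data, 1, model) mesh.
# Assumes an SPMD-enabled torch_xla runtime; shapes and values are illustrative.
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr

hidden_states = torch.zeros(8, 128, 4096).to(xm.xla_device())

spmd_2d_sharding = 0                   # 0 -> guard skips the sharding block
if spmd_2d_sharding > 0:
    num_devices = xr.global_runtime_device_count()
    model = spmd_2d_sharding
    data = num_devices // model
    mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
    xs.mark_sharding(hidden_states, mesh, (0, 1, 2))

# Prints an empty string when no sharding annotation has been applied.
print(torch_xla._XLAC._get_xla_sharding_spec(hidden_states))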
