@@ -401,21 +401,22 @@ def forward(
         if not output_attentions:
             attn_weights = None

-        # Apply 2D sharding:
-        # activation (data, None, model)
-        import torch_xla.core.xla_model as xm
-        import torch_xla.experimental.xla_sharding as xs
-        import torch_xla.runtime as xr
-        import torch_xla
-        num_devices = xr.global_runtime_device_count()
-        device_ids = torch.arange(num_devices)
-        print('> Sharding activations', attn_output.shape)
-        model = self.spmd_2d_sharding
-        data = num_devices // model
-        assert model * data == num_devices
-        data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
-        xs.mark_sharding(attn_output, data_model_mesh, (0, 1, 2))
-        print(torch_xla._XLAC._get_xla_sharding_spec(attn_output))
+        if self.spmd_2d_sharding > 0:
+            # Apply 2D sharding:
+            # activation (data, None, model)
+            import torch_xla.core.xla_model as xm
+            import torch_xla.experimental.xla_sharding as xs
+            import torch_xla.runtime as xr
+            import torch_xla
+            num_devices = xr.global_runtime_device_count()
+            device_ids = torch.arange(num_devices)
+            print('> Sharding activations', attn_output.shape)
+            model = self.spmd_2d_sharding
+            data = num_devices // model
+            assert model * data == num_devices
+            data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
+            xs.mark_sharding(attn_output, data_model_mesh, (0, 1, 2))
+            print(torch_xla._XLAC._get_xla_sharding_spec(attn_output))

         return attn_output, attn_weights, past_key_value

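Both hunks apply the same pattern: gate the annotation behind the spmd_2d_sharding flag (so the unmodified path is untouched when it is 0), build a (data, 1, model) device mesh, and mark a [batch, seq, hidden] tensor with partition spec (0, 1, 2). Below is a minimal standalone sketch of that pattern, assuming a torch_xla build where HybridMesh and mark_sharding live under torch_xla.experimental.xla_sharding as in this diff (newer releases moved them to torch_xla.distributed.spmd). The sharding degree, device count, and tensor shape are illustrative, and the xr.use_spmd() call is an assumption — older builds enabled SPMD via the XLA_USE_SPMD=1 environment variable instead.

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr

# Enable SPMD execution mode (assumed here; the patched trainer
# presumably does this, or sets XLA_USE_SPMD=1, elsewhere).
xr.use_spmd()

num_devices = xr.global_runtime_device_count()
model = 4                           # tensor-parallel degree (self.spmd_2d_sharding)
data = num_devices // model         # data-parallel degree
assert model * data == num_devices  # e.g. 8 devices -> mesh (2, 1, 4)

# 3D mesh (data, 1, model); the size-1 middle axis leaves dim 1 of the
# annotated tensor effectively replicated.
mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))

# Stand-in for attn_output / hidden_states: [batch, seq, hidden].
x = torch.zeros(8, 2048, 4096, device=xm.xla_device())

# Partition spec (0, 1, 2): batch over the data axis, seq over the
# size-1 axis, hidden over the model axis.
xs.mark_sharding(x, mesh, (0, 1, 2))
print(torch_xla._XLAC._get_xla_sharding_spec(x))

Keeping the torch_xla imports inside forward, as the diff does, avoids a hard torch_xla dependency for runs that never set the flag.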
@@ -1015,21 +1016,22 @@ def forward(
         # embed positions
         # Is this the input to the model?
         hidden_states = inputs_embeds
-        # Apply 2D sharding:
-        # input (data, None, model)
-        import torch_xla.core.xla_model as xm
-        import torch_xla.experimental.xla_sharding as xs
-        import torch_xla.runtime as xr
-        import torch_xla
-        num_devices = xr.global_runtime_device_count()
-        device_ids = torch.arange(num_devices)
-        print('> Sharding hidden_states', hidden_states.shape)
-        model = self.spmd_2d_sharding
-        data = num_devices // model
-        assert model * data == num_devices
-        data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
-        xs.mark_sharding(hidden_states, data_model_mesh, (0, 1, 2))
-        print(torch_xla._XLAC._get_xla_sharding_spec(hidden_states))
+        if self.spmd_2d_sharding > 0:
+            # Apply 2D sharding:
+            # input (data, None, model)
+            import torch_xla.core.xla_model as xm
+            import torch_xla.experimental.xla_sharding as xs
+            import torch_xla.runtime as xr
+            import torch_xla
+            num_devices = xr.global_runtime_device_count()
+            device_ids = torch.arange(num_devices)
+            print('> Sharding hidden_states', hidden_states.shape)
+            model = self.spmd_2d_sharding
+            data = num_devices // model
+            assert model * data == num_devices
+            data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
+            xs.mark_sharding(hidden_states, data_model_mesh, (0, 1, 2))
+            print(torch_xla._XLAC._get_xla_sharding_spec(hidden_states))

         # decoder layers
         all_hidden_states = () if output_hidden_states else None
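Neither hunk shows where spmd_2d_sharding is set; both the attention module and the model just read it off self. A plausible wiring, with hypothetical names not taken from this commit, is to carry the degree on the model config and copy it into each module at construction time:

# Hypothetical plumbing; this commit does not show where the flag is set.
class LlamaConfigWith2DSharding:
    def __init__(self, spmd_2d_sharding=0, **kwargs):
        # 0 disables the mark_sharding calls; > 0 is the model-parallel
        # degree and must divide xr.global_runtime_device_count().
        self.spmd_2d_sharding = spmd_2d_sharding

# Then, in LlamaAttention.__init__ and LlamaModel.__init__:
#     self.spmd_2d_sharding = config.spmd_2d_sharding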