@@ -638,7 +638,7 @@ def test_inplace_add_with_sharding(self):
     self.assertEqual(sharding_spec, torch_xla._XLAC._get_xla_sharding_spec(xt))
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xt])
     self.assertIn(
-        '%custom-call.1 = f32[2,2]{1,0} custom-call(f32[2,2]{1,0} %add.1), custom_call_target="Sharding", sharding=',
+        '%custom-call.7 = f32[2,2]{1,0} custom-call(f32[2,2]{1,0} %add.6), custom_call_target="Sharding", sharding=',
         hlo)
 
   # avoid calling xr.addressable_device_count here otherwise it will init the test
@@ -738,8 +738,7 @@ def test_xla_sharded_hlo_dump(self):
                             partition_spec)
     xst2 = xst1 + 5
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xst2.global_tensor])
-    print(hlo)
-    self.assertIn('%p1.1 = f32[1,8]{1,0} parameter(1), sharding', hlo)
+    self.assertIn('%p1.3 = f32[1,8]{1,0} parameter(1), sharding', hlo)
     if torch_xla._XLAC._xla_get_auto_sharding():
       # scalar 5 should be implicitly replicated, so the pre-optimization HLO
       # shouldn't mark it with sharding.
@@ -854,13 +853,13 @@ def test_mark_sharding_ir(self):
                               (0, 1))
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([actual.global_tensor])
     self.assertIn(
-        '%custom-call.1 = f32[1,128]{1,0} custom-call(f32[1,128]{1,0} %add.1), custom_call_target="Sharding", sharding=',
+        '%custom-call.7 = f32[1,128]{1,0} custom-call(f32[1,128]{1,0} %add.6), custom_call_target="Sharding", sharding=',
         hlo)
 
     actual += 0
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([actual.global_tensor])
     self.assertIn(
-        '%add.3 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.1, f32[1,128]{1,0} %broadcast.3)',
+        '%add.12 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.9, f32[1,128]{1,0} %broadcast.11)',
         hlo)
 
     self.assertTrue(torch.allclose(expected, actual.cpu()))
@@ -1169,7 +1168,7 @@ def test_backward_optimization_barrier(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([model.fc2.weight.grad])
     self.assertIn(
-        '%opt-barrier.1 = (f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) %tuple.2)',
+        '%opt-barrier.37 = (f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) %tuple.36)',
         hlo)
 
   def test_mark_shard_scalar(self):
@@ -1226,7 +1225,7 @@ def test_spmd_full_to_shard_shape(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xx])
     self.assertEqual(xx.shape, (8, 8 // self.n_devices))
-    self.assertIn(f'%custom-call.1 = f32[8,{8 // self.n_devices}]{{1,0}}', hlo)
+    self.assertIn(f'%custom-call.2 = f32[8,{8 // self.n_devices}]{{1,0}}', hlo)
     self.assertIn(
         f'custom_call_target="SPMDFullToShardShape", sharding={{manual}}', hlo)
     self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{manual}")
@@ -1243,7 +1242,7 @@ def test_spmd_full_to_shard_shape(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xx])
     self.assertEqual(xx.shape, (8, 4))
-    self.assertIn(f'%custom-call.1 = f32[8,4]{{1,0}}', hlo)
+    self.assertIn(f'%custom-call.2 = f32[8,4]{{1,0}}', hlo)
     self.assertIn(
         f'custom_call_target="SPMDFullToShardShape", sharding={{manual}}', hlo)
     self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{manual}")
@@ -1274,7 +1273,7 @@ def test_spmd_shard_to_full_shape(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xx])
     self.assertEqual(xx.shape, x.shape)
-    self.assertIn('%custom-call.5 = f32[8,8]{1,0}', hlo)
+    self.assertIn('%custom-call.9 = f32[8,8]{1,0}', hlo)
     self.assertIn(
         'custom_call_target="SPMDShardToFullShape", sharding={replicated}', hlo)
     self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{replicated}")
@@ -1325,7 +1324,7 @@ def test_spmd_reduce_scatter(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
     self.assertIn(
-        f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{0}}, to_apply=%AddComputation.1",
+        f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{0}}, to_apply=%AddComputation.3",
         hlo)
 
     expected_x = torch.ones(8 // self.n_devices, 8) * self.n_devices
@@ -1346,7 +1345,7 @@ def test_spmd_reduce_scatter_canonical_index(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
     self.assertIn(
-        f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{1}}, to_apply=%AddComputation.1",
+        f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{1}}, to_apply=%AddComputation.3",
         hlo)
 
     expected_x = torch.ones(8, 8 // self.n_devices) * self.n_devices
@@ -1366,7 +1365,7 @@ def test_spmd_all_reduce(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
     self.assertIn(
-        f"all-reduce(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.1",
+        f"all-reduce(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.3",
         hlo)
 
     expected_x = torch.ones(8, 8) * self.n_devices
@@ -1387,7 +1386,7 @@ def test_spmd_all_reduce_scale(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
     self.assertIn(
-        f"all-reduce(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.1",
+        f"all-reduce(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.3",
         hlo)
 
     expected_x = torch.ones(8, 8) * int(self.n_devices * scale)
@@ -1741,7 +1740,7 @@ def test_annotate_custom_sharding(self):
         f'%p0.1 = f32[2,4,64,64]{{3,2,1,0}} parameter(0), sharding={original_sharding_spec}',
         hlo)
     self.assertIn(
-        f'%custom-call.1 = f32[2,4,64,64]{{3,2,1,0}} custom-call(f32[2,4,64,64]{{3,2,1,0}} %p0.1), custom_call_target="Sharding", sharding={custom_sharding_spec}',
+        f'%custom-call.2 = f32[2,4,64,64]{{3,2,1,0}} custom-call(f32[2,4,64,64]{{3,2,1,0}} %p0.1), custom_call_target="Sharding", sharding={custom_sharding_spec}',
         hlo)
     xm.mark_step()
     # Ensure that the resulting sharding spec is preserved