@@ -613,7 +613,7 @@ def test_inplace_add_with_sharding(self):
     self.assertEqual(sharding_spec, torch_xla._XLAC._get_xla_sharding_spec(xt))
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xt])
     self.assertIn(
-        '%custom-call.7 = f32[2,2]{1,0} custom-call(f32[2,2]{1,0} %add.6), custom_call_target="Sharding", sharding=',
+        '%custom-call.1 = f32[2,2]{1,0} custom-call(f32[2,2]{1,0} %add.1), custom_call_target="Sharding", sharding=',
         hlo)

   # avoid calling xr.addressable_device_count here otherwise it will init the test
@@ -713,7 +713,8 @@ def test_xla_sharded_hlo_dump(self):
                             partition_spec)
     xst2 = xst1 + 5
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xst2.global_tensor])
-    self.assertIn('%p1.3 = f32[1,8]{1,0} parameter(1), sharding', hlo)
+    print(hlo)
+    self.assertIn('%p1.1 = f32[1,8]{1,0} parameter(1), sharding', hlo)
     if torch_xla._XLAC._xla_get_auto_sharding():
       # scalar 5 should be implicitly replicated, so the pre-optimization HLO
       # shouldn't mark it with sharding.
@@ -828,13 +829,13 @@ def test_mark_sharding_ir(self):
                               (0, 1))
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([actual.global_tensor])
     self.assertIn(
-        '%custom-call.7 = f32[1,128]{1,0} custom-call(f32[1,128]{1,0} %add.6), custom_call_target="Sharding", sharding=',
+        '%custom-call.1 = f32[1,128]{1,0} custom-call(f32[1,128]{1,0} %add.1), custom_call_target="Sharding", sharding=',
         hlo)

     actual += 0
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([actual.global_tensor])
     self.assertIn(
-        '%add.12 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.9, f32[1,128]{1,0} %broadcast.11)',
+        '%add.3 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.1, f32[1,128]{1,0} %broadcast.3)',
         hlo)

     self.assertTrue(torch.allclose(expected, actual.cpu()))
@@ -1141,7 +1142,7 @@ def test_backward_optimization_barrier(self):

     hlo = torch_xla._XLAC._get_xla_tensors_hlo([model.fc2.weight.grad])
     self.assertIn(
-        '%opt-barrier.37 = (f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) %tuple.36)',
+        '%opt-barrier.1 = (f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) %tuple.2)',
         hlo)

   def test_mark_shard_scalar(self):
@@ -1198,7 +1199,7 @@ def test_spmd_full_to_shard_shape(self):

     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xx])
     self.assertEqual(xx.shape, (8, 8 // self.n_devices))
-    self.assertIn(f'%custom-call.2 = f32[8,{8 // self.n_devices}]{{1,0}}', hlo)
+    self.assertIn(f'%custom-call.1 = f32[8,{8 // self.n_devices}]{{1,0}}', hlo)
     self.assertIn(
         f'custom_call_target="SPMDFullToShardShape", sharding={{manual}}', hlo)
     self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{manual}")
@@ -1215,7 +1216,7 @@ def test_spmd_full_to_shard_shape(self):

     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xx])
     self.assertEqual(xx.shape, (8, 4))
-    self.assertIn(f'%custom-call.2 = f32[8,4]{{1,0}}', hlo)
+    self.assertIn(f'%custom-call.1 = f32[8,4]{{1,0}}', hlo)
     self.assertIn(
         f'custom_call_target="SPMDFullToShardShape", sharding={{manual}}', hlo)
     self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{manual}")
@@ -1246,7 +1247,7 @@ def test_spmd_shard_to_full_shape(self):

     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xx])
     self.assertEqual(xx.shape, x.shape)
-    self.assertIn('%custom-call.9 = f32[8,8]{1,0}', hlo)
+    self.assertIn('%custom-call.5 = f32[8,8]{1,0}', hlo)
     self.assertIn(
         'custom_call_target="SPMDShardToFullShape", sharding={replicated}', hlo)
     self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{replicated}")
@@ -1297,7 +1298,7 @@ def test_spmd_reduce_scatter(self):

     hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
     self.assertIn(
-        f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{0}}, to_apply=%AddComputation.3",
+        f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{0}}, to_apply=%AddComputation.1",
         hlo)

     expected_x = torch.ones(8 // self.n_devices, 8) * self.n_devices
@@ -1318,7 +1319,7 @@ def test_spmd_reduce_scatter_canonical_index(self):

     hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
     self.assertIn(
-        f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{1}}, to_apply=%AddComputation.3",
+        f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{1}}, to_apply=%AddComputation.1",
         hlo)

     expected_x = torch.ones(8, 8 // self.n_devices) * self.n_devices
@@ -1338,7 +1339,7 @@ def test_spmd_all_reduce(self):

     hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
     self.assertIn(
-        f"all-reduce(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.3",
+        f"all-reduce(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.1",
         hlo)

     expected_x = torch.ones(8, 8) * self.n_devices
@@ -1359,7 +1360,7 @@ def test_spmd_all_reduce_scale(self):

     hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
     self.assertIn(
-        f"all-reduce(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.3",
+        f"all-reduce(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.1",
         hlo)

     expected_x = torch.ones(8, 8) * int(self.n_devices * scale)
@@ -1713,7 +1714,7 @@ def test_annotate_custom_sharding(self):
         f'%p0.1 = f32[2,4,64,64]{{3,2,1,0}} parameter(0), sharding={original_sharding_spec}',
         hlo)
     self.assertIn(
-        f'%custom-call.2 = f32[2,4,64,64]{{3,2,1,0}} custom-call(f32[2,4,64,64]{{3,2,1,0}} %p0.1), custom_call_target="Sharding", sharding={custom_sharding_spec}',
+        f'%custom-call.1 = f32[2,4,64,64]{{3,2,1,0}} custom-call(f32[2,4,64,64]{{3,2,1,0}} %p0.1), custom_call_target="Sharding", sharding={custom_sharding_spec}',
         hlo)
     xm.mark_step()
     # Ensure that the resulting sharding spec is preserved
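These assertions all match substrings of the pre-optimization HLO text returned by torch_xla._XLAC._get_xla_tensors_hlo, so the expected instruction names follow whatever numbering the HLO builder emits. Below is a minimal sketch (not part of the diff) for reproducing such a dump, assuming an SPMD-enabled PJRT runtime and the torch_xla.distributed.spmd API; the shapes and mesh layout are illustrative only.

import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()
n_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(n_devices), (1, n_devices))

xt = torch.ones(2, 2).to(xm.xla_device())
xs.mark_sharding(xt, mesh, (0, 1))
xt += 1  # in-place op traced on top of the sharded tensor

# Dump the pre-optimization HLO for the live tensor; instruction names such as
# %custom-call.1 or %add.1 in the assertions above come from output like this.
print(torch_xla._XLAC._get_xla_tensors_hlo([xt]))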