@@ -57,24 +57,25 @@ def visualize_sharding(sharding: str,
5757 # eg: '{devices=[2,2]0,1,2,3}'
5858 # eg: '{replicated}'
5959 # eg: '{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}'
60+ print (f"Visualizing { sharding } (showing up to the first two dimensions)" )
6061 if sharding == '{replicated}' or len (sharding ) == 0 :
6162 heights = 1
6263 widths = 1
6364 num_devices = xr .global_runtime_device_count ()
6465 device_ids = list (range (num_devices ))
6566 slices .setdefault ((0 , 0 ), device_ids )
6667 else :
67- sharding_spac = sharding [sharding .index ('[' ):sharding .index (']' ) + 1 ]
68+ sharding_spac = sharding [sharding .index ('[' ) + 1 :sharding .index (']' )]. split ( "," )
6869 device_list_original = sharding .split (' last_tile_dim_replicate' )
6970 if len (device_list_original ) == 2 and device_list_original [1 ] == '}' :
7071 try :
7172 device_list_original_first = device_list_original [0 ]
7273 device_list = device_list_original_first [device_list_original_first .
7374 index (']' ) + 1 :]
7475 device_indices_map = [int (s ) for s in device_list .split (',' )]
75- heights = int (sharding_spac [1 ])
76- widths = int (sharding_spac [3 ])
77- last_dim_depth = int (sharding_spac [5 ])
76+ heights = int (sharding_spac [0 ])
77+ widths = int (sharding_spac [1 ])
78+ last_dim_depth = int (sharding_spac [- 1 ])
7879 devices_len = len (device_indices_map )
7980 len_after_dim_down = devices_len // last_dim_depth
8081 for i in range (len_after_dim_down ):
@@ -96,8 +97,8 @@ def visualize_sharding(sharding: str,
9697 device_list = device_list_original_first [device_list_original_first .
9798 index (']' ) + 1 :- 1 ]
9899 device_indices_map = [int (i ) for i in device_list .split (',' )]
99- heights = int (sharding_spac [1 ])
100- widths = int (sharding_spac [3 ])
100+ heights = int (sharding_spac [0 ])
101+ widths = int (sharding_spac [1 ])
101102 devices_len = len (device_indices_map )
102103 for i in range (devices_len ):
103104 slices .setdefault ((i // widths , i % widths ), device_indices_map [i ])
0 commit comments