File tree Expand file tree Collapse file tree 1 file changed +5
-5
lines changed
torchrec/distributed/test_utils Expand file tree Collapse file tree 1 file changed +5
-5
lines changed Original file line number Diff line number Diff line change @@ -304,6 +304,7 @@ def generate_variable_batch_input(
304
304
strides_per_rank_per_feature = {}
305
305
inverse_indices_per_rank_per_feature = {}
306
306
label_per_rank = []
307
+
307
308
for rank in range (world_size ):
308
309
# keys, values, lengths, strides
309
310
lengths_per_rank_per_feature [rank ] = {}
@@ -375,12 +376,11 @@ def generate_variable_batch_input(
375
376
accum_batch_size = 0
376
377
inverse_indices = []
377
378
for rank in range (world_size ):
378
- inverse_indices += [
379
- index + accum_batch_size
380
- for index in inverse_indices_per_rank_per_feature [rank ][key ]
381
- ]
379
+ inverse_indices .append (
380
+ inverse_indices_per_rank_per_feature [rank ][key ] + accum_batch_size
381
+ )
382
382
accum_batch_size += strides_per_rank_per_feature [rank ][key ]
383
- inverse_indices_list .append (torch .IntTensor (inverse_indices ))
383
+ inverse_indices_list .append (torch .cat (inverse_indices ))
384
384
global_inverse_indices = (list (keys .keys ()), torch .stack (inverse_indices_list ))
385
385
if global_constant_batch :
386
386
global_offsets = []
You can’t perform that action at this time.
0 commit comments