diff --git a/torchrec/distributed/test_utils/test_model.py b/torchrec/distributed/test_utils/test_model.py index 89215cafd..4bc70d180 100644 --- a/torchrec/distributed/test_utils/test_model.py +++ b/torchrec/distributed/test_utils/test_model.py @@ -304,6 +304,7 @@ def generate_variable_batch_input( strides_per_rank_per_feature = {} inverse_indices_per_rank_per_feature = {} label_per_rank = [] + for rank in range(world_size): # keys, values, lengths, strides lengths_per_rank_per_feature[rank] = {} @@ -375,12 +376,11 @@ def generate_variable_batch_input( accum_batch_size = 0 inverse_indices = [] for rank in range(world_size): - inverse_indices += [ - index + accum_batch_size - for index in inverse_indices_per_rank_per_feature[rank][key] - ] + inverse_indices.append( + inverse_indices_per_rank_per_feature[rank][key] + accum_batch_size + ) accum_batch_size += strides_per_rank_per_feature[rank][key] - inverse_indices_list.append(torch.IntTensor(inverse_indices)) + inverse_indices_list.append(torch.cat(inverse_indices)) global_inverse_indices = (list(keys.keys()), torch.stack(inverse_indices_list)) if global_constant_batch: global_offsets = []