Skip to content

Commit e02c9e5

Browse files
PaulZhang12facebook-github-bot
authored andcommitted
Optimize VBE input generation (#1854)
Summary: While authoring VBE benchmarks, this code block was very inefficient as determined by the profiler: https://www.internalfb.com/fburl?nopassthru=1&key=scuba%2Fpyperf_experimental%2Fon_demand%2Fnbkvd0xv. This diff optimizes the code by vectorizing the addition and appending to the list Differential Revision: D55882021
1 parent f43660b commit e02c9e5

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

torchrec/distributed/test_utils/test_model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,7 @@ def generate_variable_batch_input(
304304
strides_per_rank_per_feature = {}
305305
inverse_indices_per_rank_per_feature = {}
306306
label_per_rank = []
307+
307308
for rank in range(world_size):
308309
# keys, values, lengths, strides
309310
lengths_per_rank_per_feature[rank] = {}
@@ -375,12 +376,11 @@ def generate_variable_batch_input(
375376
accum_batch_size = 0
376377
inverse_indices = []
377378
for rank in range(world_size):
378-
inverse_indices += [
379-
index + accum_batch_size
380-
for index in inverse_indices_per_rank_per_feature[rank][key]
381-
]
379+
inverse_indices.append(
380+
inverse_indices_per_rank_per_feature[rank][key] + accum_batch_size
381+
)
382382
accum_batch_size += strides_per_rank_per_feature[rank][key]
383-
inverse_indices_list.append(torch.IntTensor(inverse_indices))
383+
inverse_indices_list.append(torch.cat(inverse_indices))
384384
global_inverse_indices = (list(keys.keys()), torch.stack(inverse_indices_list))
385385
if global_constant_batch:
386386
global_offsets = []

0 commit comments

Comments
 (0)