
Commit 08fcd6f

PaulZhang12 authored and facebook-github-bot committed
VBE training benchmarks (Manual) (#1855)
Summary: Set TorchRec's distributed training benchmarks to include VBE (variable batch embeddings).

Differential Revision: D55882022
1 parent e02c9e5 commit 08fcd6f

File tree

2 files changed: +56 -10 lines changed

torchrec/distributed/benchmark/benchmark_train.py
torchrec/distributed/benchmark/benchmark_utils.py

torchrec/distributed/benchmark/benchmark_train.py

Lines changed: 26 additions & 0 deletions
@@ -71,6 +71,7 @@ def benchmark_ebc(
     args: argparse.Namespace,
     output_dir: str,
     pooling_configs: Optional[List[int]] = None,
+    variable_batch_embeddings: bool = False,
 ) -> List[BenchmarkResult]:
     table_configs = get_tables(tables, data_type=DataType.FP32)
     sharder = TestEBCSharder(
@@ -104,6 +105,9 @@ def benchmark_ebc(
     if pooling_configs:
         args_kwargs["pooling_configs"] = pooling_configs
 
+    if variable_batch_embeddings:
+        args_kwargs["variable_batch_embeddings"] = variable_batch_embeddings
+
     return benchmark_module(
         module=module,
         sharder=sharder,
@@ -153,6 +157,7 @@ def main() -> None:
             mb = int(float(num * dim) / 1024 / 1024) * 4
             tables_info += f"\nTABLE[{i}][{num:9}, {dim:4}] {mb:6}Mb"
 
+        ### Benchmark no VBE
         report: str = (
             f"REPORT BENCHMARK {datetime_sfx} world_size:{args.world_size} batch_size:{args.batch_size}\n"
         )
@@ -176,6 +181,27 @@ def main() -> None:
             )
         )
 
+        ### Benchmark with VBE
+        report: str = (
+            f"REPORT BENCHMARK (VBE) {datetime_sfx} world_size:{args.world_size} batch_size:{args.batch_size}\n"
+        )
+        report += f"Module: {module_name} (VBE)\n"
+        report += tables_info
+        report += "\n"
+        report_file = f"{output_dir}/run_vbe.report"
+
+        benchmark_results_per_module.append(
+            benchmark_func(shrunk_table_sizes, args, output_dir, pooling_configs, True)
+        )
+        write_report_funcs_per_module.append(
+            partial(
+                write_report,
+                report_file=report_file,
+                report_str=report,
+                num_requests=num_requests,
+            )
+        )
+
     for i, write_report_func in enumerate(write_report_funcs_per_module):
         write_report_func(benchmark_results_per_module[i])
 
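A minimal usage sketch (not part of the commit's diff): the call in main() above passes the new flag positionally, so each module is now benchmarked twice, once with the default dense-batch inputs and once with variable-batch inputs. Identifier names mirror main(); the leading table-sizes argument to benchmark_ebc is taken from that call site.

    # Sketch only: mirrors the two calls made in main() above.
    results_dense = benchmark_ebc(shrunk_table_sizes, args, output_dir, pooling_configs)
    results_vbe = benchmark_ebc(
        shrunk_table_sizes,
        args,
        output_dir,
        pooling_configs,
        True,  # variable_batch_embeddings, forwarded to benchmark_module via args_kwargs
    )
    # The VBE run writes its report to run_vbe.report (see report_file in main() above).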

torchrec/distributed/benchmark/benchmark_utils.py

Lines changed: 30 additions & 10 deletions
@@ -269,19 +269,32 @@ def get_inputs(
     num_inputs: int,
     train: bool,
     pooling_configs: Optional[List[int]] = None,
+    variable_batch_embeddings: bool = False,
 ) -> List[List[KeyedJaggedTensor]]:
     inputs_batch: List[List[KeyedJaggedTensor]] = []
 
+    if variable_batch_embeddings and not train:
+        raise RuntimeError("Variable batch size is only supported in training mode")
+
     for _ in range(num_inputs):
-        _, model_input_by_rank = ModelInput.generate(
-            batch_size=batch_size,
-            world_size=world_size,
-            num_float_features=0,
-            tables=tables,
-            weighted_tables=[],
-            long_indices=False,
-            tables_pooling=pooling_configs,
-        )
+        if variable_batch_embeddings:
+            _, model_input_by_rank = ModelInput.generate_variable_batch_input(
+                average_batch_size=batch_size,
+                world_size=world_size,
+                num_float_features=0,
+                # pyre-ignore
+                tables=tables,
+            )
+        else:
+            _, model_input_by_rank = ModelInput.generate(
+                batch_size=batch_size,
+                world_size=world_size,
+                num_float_features=0,
+                tables=tables,
+                weighted_tables=[],
+                long_indices=False,
+                tables_pooling=pooling_configs,
+            )
 
         if train:
             sparse_features_by_rank = [
@@ -770,6 +783,7 @@ def benchmark_module(
     func_to_benchmark: Callable[..., None] = default_func_to_benchmark,
     benchmark_func_kwargs: Optional[Dict[str, Any]] = None,
     pooling_configs: Optional[List[int]] = None,
+    variable_batch_embeddings: bool = False,
 ) -> List[BenchmarkResult]:
     """
     Args:
@@ -820,7 +834,13 @@ def benchmark_module(
 
     num_inputs_to_gen: int = warmup_iters + bench_iters + prof_iters
     inputs = get_inputs(
-        tables, batch_size, world_size, num_inputs_to_gen, train, pooling_configs
+        tables,
+        batch_size,
+        world_size,
+        num_inputs_to_gen,
+        train,
+        pooling_configs,
+        variable_batch_embeddings,
     )
 
     warmup_inputs = [rank_inputs[:warmup_iters] for rank_inputs in inputs]
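For reference, a minimal sketch of exercising the updated input generation directly, assuming table_configs is the List[EmbeddingBagConfig] produced elsewhere by get_tables(); the numeric values are placeholder examples and the positional order follows the get_inputs call in benchmark_module above.

    # Sketch only: generate training inputs with variable batch sizes per rank.
    vbe_inputs = get_inputs(
        table_configs,  # tables
        512,            # batch_size, used as average_batch_size on the VBE path
        2,              # world_size
        10,             # num_inputs
        True,           # train; the VBE path raises RuntimeError when train is False
        None,           # pooling_configs, not forwarded on the VBE path
        True,           # variable_batch_embeddings
    )
    # Returns List[List[KeyedJaggedTensor]]; benchmark_module slices each rank's list
    # into warmup, benchmark, and profiling iterations.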

0 commit comments