@@ -269,19 +269,32 @@ def get_inputs(
269
269
num_inputs : int ,
270
270
train : bool ,
271
271
pooling_configs : Optional [List [int ]] = None ,
272
+ variable_batch_embeddings : bool = False ,
272
273
) -> List [List [KeyedJaggedTensor ]]:
273
274
inputs_batch : List [List [KeyedJaggedTensor ]] = []
274
275
276
+ if variable_batch_embeddings and not train :
277
+ raise RuntimeError ("Variable batch size is only supported in training mode" )
278
+
275
279
for _ in range (num_inputs ):
276
- _ , model_input_by_rank = ModelInput .generate (
277
- batch_size = batch_size ,
278
- world_size = world_size ,
279
- num_float_features = 0 ,
280
- tables = tables ,
281
- weighted_tables = [],
282
- long_indices = False ,
283
- tables_pooling = pooling_configs ,
284
- )
280
+ if variable_batch_embeddings :
281
+ _ , model_input_by_rank = ModelInput .generate_variable_batch_input (
282
+ average_batch_size = batch_size ,
283
+ world_size = world_size ,
284
+ num_float_features = 0 ,
285
+ # pyre-ignore
286
+ tables = tables ,
287
+ )
288
+ else :
289
+ _ , model_input_by_rank = ModelInput .generate (
290
+ batch_size = batch_size ,
291
+ world_size = world_size ,
292
+ num_float_features = 0 ,
293
+ tables = tables ,
294
+ weighted_tables = [],
295
+ long_indices = False ,
296
+ tables_pooling = pooling_configs ,
297
+ )
285
298
286
299
if train :
287
300
sparse_features_by_rank = [
@@ -770,6 +783,7 @@ def benchmark_module(
770
783
func_to_benchmark : Callable [..., None ] = default_func_to_benchmark ,
771
784
benchmark_func_kwargs : Optional [Dict [str , Any ]] = None ,
772
785
pooling_configs : Optional [List [int ]] = None ,
786
+ variable_batch_embeddings : bool = False ,
773
787
) -> List [BenchmarkResult ]:
774
788
"""
775
789
Args:
@@ -820,7 +834,13 @@ def benchmark_module(
820
834
821
835
num_inputs_to_gen : int = warmup_iters + bench_iters + prof_iters
822
836
inputs = get_inputs (
823
- tables , batch_size , world_size , num_inputs_to_gen , train , pooling_configs
837
+ tables ,
838
+ batch_size ,
839
+ world_size ,
840
+ num_inputs_to_gen ,
841
+ train ,
842
+ pooling_configs ,
843
+ variable_batch_embeddings ,
824
844
)
825
845
826
846
warmup_inputs = [rank_inputs [:warmup_iters ] for rank_inputs in inputs ]
0 commit comments