From 4873bdd1db7621cd611598901669cf7c7f068b8b Mon Sep 17 00:00:00 2001 From: Ivan Kobzarev Date: Mon, 20 May 2024 06:30:30 -0700 Subject: [PATCH] Use non-variable stride per rank path for dynamo (#2018) Summary: Pick non-variable stride per rank path for dynamo for now. In the future this will be solved with first-batch path memoization. Differential Revision: D57562279 --- torchrec/distributed/dist_data.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/torchrec/distributed/dist_data.py b/torchrec/distributed/dist_data.py index ff55b0e19..b777278fb 100644 --- a/torchrec/distributed/dist_data.py +++ b/torchrec/distributed/dist_data.py @@ -186,9 +186,13 @@ def _get_recat_tensor_compute( + (feature_order.expand(LS, FO_S0) * LS) ).reshape(-1) - vb_condition = batch_size_per_rank is not None and any( - bs != batch_size_per_rank[0] for bs in batch_size_per_rank - ) + # Use non variable stride per rank path for dynamo + # TODO(ivankobzarev): Implement memorization of the path from the first batch. + vb_condition = False + if not is_torchdynamo_compiling(): + vb_condition = batch_size_per_rank is not None and any( + bs != batch_size_per_rank[0] for bs in batch_size_per_rank + ) if vb_condition: batch_size_per_rank_tensor = torch._refs.tensor(