diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py
index d8554edea..6fa024334 100644
--- a/torchrec/distributed/embedding.py
+++ b/torchrec/distributed/embedding.py
@@ -338,6 +338,7 @@ def __init__(
         features_to_permute_indices: Optional[Dict[str, List[int]]] = None,
         module_fqn: Optional[str] = None,
         sharding_types: Optional[List[str]] = None,
+        use_gather_select: bool = False,
     ) -> None:
         super().__init__()
         self._awaitables_per_sharding = awaitables_per_sharding
@@ -348,6 +349,7 @@ def __init__(
         self._ctx = ctx
         self._module_fqn = module_fqn
         self._sharding_types = sharding_types
+        self._use_gather_select = use_gather_select
 
     def _wait_impl(self) -> Dict[str, JaggedTensor]:
         jt_dict: Dict[str, JaggedTensor] = {}
@@ -389,6 +391,7 @@ def _wait_impl(self) -> Dict[str, JaggedTensor]:
                     original_features=original_features,
                     reverse_indices=reverse_indices,
                     seq_vbe_ctx=seq_vbe_ctx,
+                    use_gather_select=self._use_gather_select,
                 )
             )
         return jt_dict
@@ -529,6 +532,7 @@ def __init__(
             module.embedding_configs(), table_name_to_parameter_sharding
         )
         self._need_indices: bool = module.need_indices()
+        self._use_gather_select: bool = module.use_gather_select()
         self._inverse_indices_permute_per_sharding: Optional[List[torch.Tensor]] = None
         self._skip_missing_weight_key: List[str] = []
 
@@ -1563,6 +1567,7 @@ def output_dist(
             need_indices=self._need_indices,
             features_to_permute_indices=self._features_to_permute_indices,
             ctx=ctx,
+            use_gather_select=self._use_gather_select,
         )
 
     def compute_and_output_dist(
@@ -1612,6 +1617,7 @@ def compute_and_output_dist(
             ctx=ctx,
             module_fqn=self._module_fqn,
             sharding_types=list(self._sharding_type_to_sharding.keys()),
+            use_gather_select=self._use_gather_select,
         )
 
     def _embedding_dim_for_sharding_type(self, sharding_type: str) -> int:
diff --git a/torchrec/modules/embedding_modules.py b/torchrec/modules/embedding_modules.py
index 00726e3c2..99b4d0392 100644
--- a/torchrec/modules/embedding_modules.py
+++ b/torchrec/modules/embedding_modules.py
@@ -408,6 +408,7 @@ def __init__(  # noqa C901
         tables: List[EmbeddingConfig],
         device: Optional[torch.device] = None,
         need_indices: bool = False,
+        use_gather_select: bool = False,
     ) -> None:
         super().__init__()
         torch._C._log_api_usage_once(f"torchrec.modules.{self.__class__.__name__}")
@@ -415,6 +416,7 @@ def __init__(  # noqa C901
         self._embedding_configs = tables
         self._embedding_dim: int = -1
         self._need_indices: bool = need_indices
+        self._use_gather_select: bool = use_gather_select
         self._device: torch.device = (
             device if device is not None else torch.device("cpu")
         )
@@ -541,3 +543,10 @@ def reset_parameters(self) -> None:
             param = self.embeddings[f"{table_config.name}"].weight
             # pyre-ignore
             table_config.init_fn(param)
+
+    def use_gather_select(self) -> bool:
+        """
+        Returns:
+            bool: Whether the EmbeddingCollection uses torch.gather to select embeddings.
+        """
+        return self._use_gather_select
diff --git a/torchrec/modules/utils.py b/torchrec/modules/utils.py
index 2d6f4b4a5..95b082b4a 100644
--- a/torchrec/modules/utils.py
+++ b/torchrec/modules/utils.py
@@ -245,14 +245,22 @@ def construct_jagged_tensors(
     original_features: Optional[KeyedJaggedTensor] = None,
     reverse_indices: Optional[torch.Tensor] = None,
     seq_vbe_ctx: Optional[SequenceVBEContext] = None,
+    use_gather_select: bool = False,
 ) -> Dict[str, JaggedTensor]:
     with record_function("## construct_jagged_tensors ##"):
         if original_features is not None:
             features = original_features
         if reverse_indices is not None:
-            embeddings = torch.index_select(
-                embeddings, 0, reverse_indices.to(torch.int32)
-            )
+            if use_gather_select:
+                # gather has better backward performance than index_select in many cases
+                expanded_indices = reverse_indices.unsqueeze(1).expand(
+                    -1, embeddings.size(-1)
+                )
+                embeddings = torch.gather(embeddings, 0, expanded_indices)
+            else:
+                embeddings = torch.index_select(
+                    embeddings, 0, reverse_indices.to(torch.int32)
+                )
 
         ret: Dict[str, JaggedTensor] = {}
         if seq_vbe_ctx is not None: