6 changes: 6 additions & 0 deletions torchrec/distributed/embedding.py
@@ -338,6 +338,7 @@ def __init__(
features_to_permute_indices: Optional[Dict[str, List[int]]] = None,
module_fqn: Optional[str] = None,
sharding_types: Optional[List[str]] = None,
use_gather_select: bool = False,
) -> None:
super().__init__()
self._awaitables_per_sharding = awaitables_per_sharding
@@ -348,6 +349,7 @@ def __init__(
self._ctx = ctx
self._module_fqn = module_fqn
self._sharding_types = sharding_types
self._use_gather_select = use_gather_select

def _wait_impl(self) -> Dict[str, JaggedTensor]:
jt_dict: Dict[str, JaggedTensor] = {}
@@ -389,6 +391,7 @@ def _wait_impl(self) -> Dict[str, JaggedTensor]:
original_features=original_features,
reverse_indices=reverse_indices,
seq_vbe_ctx=seq_vbe_ctx,
use_gather_select=self._use_gather_select,
)
)
return jt_dict
@@ -529,6 +532,7 @@ def __init__(
module.embedding_configs(), table_name_to_parameter_sharding
)
self._need_indices: bool = module.need_indices()
self._use_gather_select: bool = module.use_gather_select()
self._inverse_indices_permute_per_sharding: Optional[List[torch.Tensor]] = None
self._skip_missing_weight_key: List[str] = []

@@ -1563,6 +1567,7 @@ def output_dist(
need_indices=self._need_indices,
features_to_permute_indices=self._features_to_permute_indices,
ctx=ctx,
use_gather_select=self._use_gather_select,
)

def compute_and_output_dist(
@@ -1612,6 +1617,7 @@ def compute_and_output_dist(
ctx=ctx,
module_fqn=self._module_fqn,
sharding_types=list(self._sharding_type_to_sharding.keys()),
use_gather_select=self._use_gather_select,
)

def _embedding_dim_for_sharding_type(self, sharding_type: str) -> int:
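The wiring above follows the existing output-dist awaitable pattern: the flag is captured when the awaitable is created and only consumed when wait() assembles the per-feature JaggedTensors. A simplified, hypothetical sketch of that pattern (the real awaitable in ShardedEmbeddingCollection carries more state than shown here):

from typing import Dict, List

from torchrec.modules.utils import construct_jagged_tensors
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor


class _JaggedTensorDictAwaitable:
    """Toy stand-in for the sharded collection's output awaitable."""

    def __init__(
        self,
        features: KeyedJaggedTensor,
        embedding_names: List[str],
        embeddings_awaitable,  # resolves to the looked-up embedding tensor
        use_gather_select: bool = False,
    ) -> None:
        self._features = features
        self._embedding_names = embedding_names
        self._embeddings_awaitable = embeddings_awaitable
        self._use_gather_select = use_gather_select

    def wait(self) -> Dict[str, JaggedTensor]:
        embeddings = self._embeddings_awaitable.wait()  # collective completes here
        # Forward the flag so the gather/index_select choice made at module
        # construction time is honored when the JaggedTensors are built.
        return construct_jagged_tensors(
            embeddings=embeddings,
            features=self._features,
            embedding_names=self._embedding_names,
            use_gather_select=self._use_gather_select,
        )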
9 changes: 9 additions & 0 deletions torchrec/modules/embedding_modules.py
@@ -408,13 +408,15 @@ def __init__( # noqa C901
tables: List[EmbeddingConfig],
device: Optional[torch.device] = None,
need_indices: bool = False,
use_gather_select: bool = False,
) -> None:
super().__init__()
torch._C._log_api_usage_once(f"torchrec.modules.{self.__class__.__name__}")
self.embeddings: nn.ModuleDict = nn.ModuleDict()
self._embedding_configs = tables
self._embedding_dim: int = -1
self._need_indices: bool = need_indices
self._use_gather_select: bool = use_gather_select
self._device: torch.device = (
device if device is not None else torch.device("cpu")
)
@@ -541,3 +543,10 @@ def reset_parameters(self) -> None:
param = self.embeddings[f"{table_config.name}"].weight
# pyre-ignore
table_config.init_fn(param)

def use_gather_select(self) -> bool:
"""
Returns:
bool: Whether the EmbeddingCollection uses torch.gather to select embeddings.
"""
return self._use_gather_select
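For reference, a minimal usage sketch of the new flag on the unsharded module (only the constructor argument and accessor added in this diff are assumed; the flag takes effect on the sharded path, where ShardedEmbeddingCollection reads it via module.use_gather_select()):

import torch
from torchrec.modules.embedding_configs import EmbeddingConfig
from torchrec.modules.embedding_modules import EmbeddingCollection

ec = EmbeddingCollection(
    tables=[
        EmbeddingConfig(
            name="t1",
            embedding_dim=8,
            num_embeddings=100,
            feature_names=["f1"],
        )
    ],
    device=torch.device("cpu"),
    use_gather_select=True,  # opt in to the gather-based row selection
)

# The sharder-facing accessor simply echoes the constructor argument.
assert ec.use_gather_select()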
14 changes: 11 additions & 3 deletions torchrec/modules/utils.py
@@ -245,14 +245,22 @@ def construct_jagged_tensors(
original_features: Optional[KeyedJaggedTensor] = None,
reverse_indices: Optional[torch.Tensor] = None,
seq_vbe_ctx: Optional[SequenceVBEContext] = None,
use_gather_select: bool = False,
) -> Dict[str, JaggedTensor]:
with record_function("## construct_jagged_tensors ##"):
if original_features is not None:
features = original_features
if reverse_indices is not None:
- embeddings = torch.index_select(
-     embeddings, 0, reverse_indices.to(torch.int32)
- )
+ if use_gather_select:
+     # gather has better backward performance than index_select in many cases
+     expanded_indices = reverse_indices.unsqueeze(1).expand(
+         -1, embeddings.size(-1)
+     )
+     embeddings = torch.gather(embeddings, 0, expanded_indices)
+ else:
+     embeddings = torch.index_select(
+         embeddings, 0, reverse_indices.to(torch.int32)
+     )
ret: Dict[str, JaggedTensor] = {}

if seq_vbe_ctx is not None:
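The utils.py change is where the flag actually takes effect: with use_gather_select=True, construct_jagged_tensors expands the 1-D reverse indices to the embedding width and selects rows with torch.gather instead of torch.index_select. A standalone sketch with plain tensors (illustrative shapes, not the TorchRec call site) showing that both paths pick the same rows and both support backward:

import torch

embeddings = torch.randn(6, 4, requires_grad=True)  # (num_rows, embedding_dim)
reverse_indices = torch.tensor([3, 0, 0, 5, 2])      # output slot i takes row reverse_indices[i]

# Original path: index_select along dim 0 with 1-D indices.
out_index_select = torch.index_select(embeddings, 0, reverse_indices.to(torch.int32))

# New path: expand indices to (num_outputs, embedding_dim) so gather can copy whole rows.
expanded_indices = reverse_indices.unsqueeze(1).expand(-1, embeddings.size(-1))
out_gather = torch.gather(embeddings, 0, expanded_indices)

torch.testing.assert_close(out_index_select, out_gather)  # identical forward results
out_gather.sum().backward()  # gradient flows through gather's scatter-style backward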