
Commit 5458b03

Jingchang Zhang authored and meta-codesync[bot] committed
Add option to use gather to select embeddings in EC (#3479)
Summary: Pull Request resolved: #3479

torch.index_select relies on atomic adds in its backward pass, so its backward performance can be noticeably worse than torch.gather. This diff gives users control over the indexing step so the more suitable operator can be selected for a given workload.

Perf comparison on pure operators (forward + backward); torch.gather was the faster (or tied) operator in every case:

| Scenario | Config | torch.gather (s) | torch.index_select (s) | index_select relative speed |
| --- | --- | --- | --- | --- |
| 2D Embedding, No Repetition | shape=(1000000, 256), dim=0, indices=100000, unique=95300 (95.3%) | 0.9439 | 1.0509 | 0.90x |
| 2D Embedding, Low Repetition | shape=(1000000, 256), dim=0, indices=100000, unique=48732 (48.7%) | 0.9076 | 1.0415 | 0.87x |
| 2D Embedding, High Repetition | shape=(1000000, 256), dim=0, indices=250000, unique=9957 (4.0%) | 1.2385 | 1.6225 | 0.76x |
| Small Vocab, Low Repetition | shape=(1000, 256), dim=0, indices=2000, unique=635 (31.8%) | 0.1502 | 0.1763 | 0.85x |
| Small Vocab, Very High Repetition | shape=(1000, 256), dim=0, indices=100000, unique=625 (0.6%) | 0.2626 | 0.4126 | 0.64x |
| Large Vocab, No Repetition | shape=(10000000, 256), dim=0, indices=10000, unique=9996 (100.0%) | 5.8014 | 5.8184 | 1.00x |
| Large Vocab, Low Repetition | shape=(10000000, 256), dim=0, indices=10000, unique=5000 (50.0%) | 5.7912 | 5.8137 | 1.00x |
| Large Vocab, High Repetition | shape=(10000000, 256), dim=0, indices=10000, unique=400 (4.0%) | 5.7784 | 5.8100 | 0.99x |

Mast Job Test:
- baseline: fire-jingchang-f816557933, torch.index_select backward takes ~37ms {F1982939713}
- exp: fire-jingchang-f816355728, torch.gather backward takes ~10ms {F1982939742}

Reviewed By: TroyGarden

Differential Revision: D85309309
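For context, a minimal timing sketch of the two selection paths (forward + backward) might look like the following. This is an illustrative standalone script, not the benchmark harness behind the numbers above; it assumes a CUDA device and mirrors the shapes of the "2D Embedding, High Repetition" row.

```python
import time

import torch


def bench(make_output, iters: int = 20) -> float:
    """Time forward + backward of a selection op, synchronizing CUDA around the loop."""
    for _ in range(3):  # warmup
        make_output().sum().backward()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        make_output().sum().backward()
    torch.cuda.synchronize()
    return time.perf_counter() - start


# High-repetition case: 250k lookups into ~10k distinct rows of a 1M x 256 table.
emb = torch.randn(1_000_000, 256, device="cuda", requires_grad=True)
idx = torch.randint(0, 10_000, (250_000,), device="cuda")

# index_select takes the 1-D index directly.
t_index_select = bench(lambda: torch.index_select(emb, 0, idx))

# gather needs the index expanded to the embedding width.
expanded = idx.unsqueeze(1).expand(-1, emb.size(-1))
t_gather = bench(lambda: torch.gather(emb, 0, expanded))

print(f"index_select: {t_index_select:.4f}s   gather: {t_gather:.4f}s")
```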
1 parent 7ddc21d commit 5458b03

File tree

3 files changed: +26 −3 lines changed


torchrec/distributed/embedding.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -338,6 +338,7 @@ def __init__(
         features_to_permute_indices: Optional[Dict[str, List[int]]] = None,
         module_fqn: Optional[str] = None,
         sharding_types: Optional[List[str]] = None,
+        use_gather_select: bool = False,
     ) -> None:
         super().__init__()
         self._awaitables_per_sharding = awaitables_per_sharding
@@ -348,6 +349,7 @@ def __init__(
         self._ctx = ctx
         self._module_fqn = module_fqn
         self._sharding_types = sharding_types
+        self._use_gather_select = use_gather_select
 
     def _wait_impl(self) -> Dict[str, JaggedTensor]:
         jt_dict: Dict[str, JaggedTensor] = {}
@@ -389,6 +391,7 @@ def _wait_impl(self) -> Dict[str, JaggedTensor]:
                     original_features=original_features,
                     reverse_indices=reverse_indices,
                     seq_vbe_ctx=seq_vbe_ctx,
+                    use_gather_select=self._use_gather_select,
                 )
             )
         return jt_dict
@@ -529,6 +532,7 @@ def __init__(
             module.embedding_configs(), table_name_to_parameter_sharding
         )
         self._need_indices: bool = module.need_indices()
+        self._use_gather_select: bool = module.use_gather_select()
         self._inverse_indices_permute_per_sharding: Optional[List[torch.Tensor]] = None
         self._skip_missing_weight_key: List[str] = []
 
@@ -1563,6 +1567,7 @@ def output_dist(
             need_indices=self._need_indices,
             features_to_permute_indices=self._features_to_permute_indices,
             ctx=ctx,
+            use_gather_select=self._use_gather_select,
         )
 
     def compute_and_output_dist(
@@ -1612,6 +1617,7 @@ def compute_and_output_dist(
             ctx=ctx,
             module_fqn=self._module_fqn,
             sharding_types=list(self._sharding_type_to_sharding.keys()),
+            use_gather_select=self._use_gather_select,
         )
 
     def _embedding_dim_for_sharding_type(self, sharding_type: str) -> int:
```

torchrec/modules/embedding_modules.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -408,13 +408,15 @@ def __init__(  # noqa C901
         tables: List[EmbeddingConfig],
         device: Optional[torch.device] = None,
         need_indices: bool = False,
+        use_gather_select: bool = False,
     ) -> None:
         super().__init__()
         torch._C._log_api_usage_once(f"torchrec.modules.{self.__class__.__name__}")
         self.embeddings: nn.ModuleDict = nn.ModuleDict()
         self._embedding_configs = tables
         self._embedding_dim: int = -1
         self._need_indices: bool = need_indices
+        self._use_gather_select: bool = use_gather_select
         self._device: torch.device = (
             device if device is not None else torch.device("cpu")
         )
@@ -541,3 +543,10 @@ def reset_parameters(self) -> None:
             param = self.embeddings[f"{table_config.name}"].weight
             # pyre-ignore
             table_config.init_fn(param)
+
+    def use_gather_select(self) -> bool:
+        """
+        Returns:
+            bool: Whether the EmbeddingCollection uses torch.gather to select embeddings.
+        """
+        return self._use_gather_select
```
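With this diff applied, the flag is set at EmbeddingCollection construction time and read by the sharded wrapper through use_gather_select() (see torchrec/distributed/embedding.py above). A minimal construction sketch; the table and feature names here are made up for illustration:

```python
import torch

from torchrec.modules.embedding_configs import EmbeddingConfig
from torchrec.modules.embedding_modules import EmbeddingCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# One small illustrative table; names and sizes are arbitrary.
ec = EmbeddingCollection(
    tables=[
        EmbeddingConfig(
            name="t1",
            embedding_dim=8,
            num_embeddings=100,
            feature_names=["f1"],
        )
    ],
    device=torch.device("cpu"),
    use_gather_select=True,  # opt in to the gather-based selection path
)

# Two samples with two ids each for feature "f1".
features = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1"],
    values=torch.tensor([1, 2, 3, 4]),
    lengths=torch.tensor([2, 2]),
)

out = ec(features)                # Dict[str, JaggedTensor]
print(out["f1"].values().shape)   # torch.Size([4, 8])
print(ec.use_gather_select())     # True
```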

torchrec/modules/utils.py

Lines changed: 11 additions & 3 deletions
```diff
@@ -245,14 +245,22 @@ def construct_jagged_tensors(
     original_features: Optional[KeyedJaggedTensor] = None,
     reverse_indices: Optional[torch.Tensor] = None,
     seq_vbe_ctx: Optional[SequenceVBEContext] = None,
+    use_gather_select: bool = False,
 ) -> Dict[str, JaggedTensor]:
     with record_function("## construct_jagged_tensors ##"):
         if original_features is not None:
             features = original_features
         if reverse_indices is not None:
-            embeddings = torch.index_select(
-                embeddings, 0, reverse_indices.to(torch.int32)
-            )
+            if use_gather_select:
+                # gather has better backward performance than index_select in many cases
+                expanded_indices = reverse_indices.unsqueeze(1).expand(
+                    -1, embeddings.size(-1)
+                )
+                embeddings = torch.gather(embeddings, 0, expanded_indices)
+            else:
+                embeddings = torch.index_select(
+                    embeddings, 0, reverse_indices.to(torch.int32)
+                )
         ret: Dict[str, JaggedTensor] = {}
 
         if seq_vbe_ctx is not None:
```
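Both branches produce the same forward result; only the backward kernels differ. A small standalone sanity-check sketch (not part of the PR's tests), showing how expanding the index makes torch.gather equivalent to torch.index_select here:

```python
import torch

embeddings = torch.randn(6, 4, requires_grad=True)
reverse_indices = torch.tensor([0, 2, 2, 5, 1])

# index_select path (existing behavior): 1-D row indices.
out_index_select = torch.index_select(
    embeddings, 0, reverse_indices.to(torch.int32)
)

# gather path: indices must be expanded to the embedding width.
expanded = reverse_indices.unsqueeze(1).expand(-1, embeddings.size(-1))
out_gather = torch.gather(embeddings, 0, expanded)

torch.testing.assert_close(out_index_select, out_gather)  # same forward result
```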
