
Commit a07bc63

Jingchang Zhang authored and facebook-github-bot committed
Add an option to move embedding lookup after sparse data dist in FusedSDD (#3345)
Summary:
Pull Request resolved: #3345

This diff adds an option that allows the embedding lookup to be triggered after the sparse data dist. This can potentially improve performance when the CPU is blocked by sparse data dist kernel launches and cannot launch the forward kernels earlier.

(Figure F1981658737 not reproduced here.)

Reviewed By: TroyGarden

Differential Revision: D81494775

fbshipit-source-id: 2d975d6784b806edbcdfdf44a181632570e2c940
1 parent 7b1ea84 commit a07bc63
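
To make the reordering concrete, here is a simplified paraphrase of the `progress()` control flow as changed by this diff. It is a sketch only, not the actual implementation: context and stream handling, `zero_grad`, and the sparse data dist launch for batch i+1 are elided.

```python
# Simplified sketch of the ordering change in progress(); a paraphrase of
# the diff below, not the real method (contexts, streams, training-mode
# handling, and the sparse data dist launch for batch i+1 are omitted).
def progress_sketch(pipeline, dataloader_iter):
    if not pipeline._embedding_lookup_after_data_dist:
        # Default behavior: launch the embedding lookup early so it
        # overlaps with the previous iteration's optimizer step.
        pipeline.start_embedding_lookup(pipeline.batches[0], pipeline.contexts[0])

    # ... zero_grad and the sparse data dist for batch i+1 happen here ...

    # Batch i+2: load data and copy to GPU.
    pipeline.enqueue_batch(dataloader_iter)

    if pipeline._embedding_lookup_after_data_dist:
        # New option: defer the lookup until the sparse data dist kernels
        # have been launched, so a CPU blocked on those launches does not
        # also delay queuing the forward kernels.
        pipeline.start_embedding_lookup(pipeline.batches[0], pipeline.contexts[0])

    # Forward on batch i.
    losses, output = pipeline._model_fwd(pipeline.batches[0])
    return output
```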

File tree

1 file changed: +10, -2 lines


torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 10 additions & 2 deletions
@@ -964,6 +964,7 @@ def __init__(
         ] = None,
         strict: bool = False,
         emb_lookup_stream: str = "data_dist",  # new, current, data_dist (default)
+        embedding_lookup_after_data_dist: bool = False,
     ) -> None:
         super().__init__(
             model=model,
@@ -975,6 +976,8 @@ def __init__(
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
         )
+        self._embedding_lookup_after_data_dist = embedding_lookup_after_data_dist
+
         if emb_lookup_stream == "new":
             self._emb_lookup_stream: Optional[torch.Stream] = (
                 (torch.get_device_module(device).Stream())
@@ -1046,8 +1049,9 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         self._set_module_context(self.contexts[0])

         # start embedding_lookup so it can overlap with previous optimizer
-        # pyre-ignore [6]
-        self.start_embedding_lookup(self.batches[0], self.contexts[0])
+        if not self._embedding_lookup_after_data_dist:
+            # pyre-ignore [6]
+            self.start_embedding_lookup(self.batches[0], self.contexts[0])

         if self._model.training:
             with record_function("## zero_grad ##"):
@@ -1064,6 +1068,10 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here
         self.enqueue_batch(dataloader_iter)

+        if self._embedding_lookup_after_data_dist:
+            # pyre-ignore [6]
+            self.start_embedding_lookup(self.batches[0], self.contexts[0])
+
         # forward
         with record_function(f"## forward {self.contexts[0].index} ##"):
             losses, output = self._model_fwd(self.batches[0])
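
A minimal usage sketch follows, assuming this flag lands on the `TrainPipelineFusedSparseDist` constructor in `train_pipelines.py` (the FusedSDD pipeline named in the title); `sharded_model`, `optimizer`, and `dataloader` are placeholders for a real sharded TorchRec model, its optimizer, and a training dataloader.

```python
import torch
from torchrec.distributed.train_pipeline.train_pipelines import (
    TrainPipelineFusedSparseDist,
)

# sharded_model, optimizer, and dataloader are placeholders; see lead-in.
pipeline = TrainPipelineFusedSparseDist(
    model=sharded_model,
    optimizer=optimizer,
    device=torch.device("cuda"),
    # New in this commit: trigger the embedding lookup only after the
    # sparse data dist kernels for the next batch have been launched.
    embedding_lookup_after_data_dist=True,
)

# Drive the pipeline with a single iterator; progress() raises
# StopIteration once the dataloader is exhausted.
dataloader_iter = iter(dataloader)
while True:
    try:
        out = pipeline.progress(dataloader_iter)
    except StopIteration:
        break
```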
