Skip to content

Commit 6a9e616

Browse files
emlin authored and meta-codesync[bot] committed
support prefetch pipeline (#5032)
Summary:
Pull Request resolved: #5032

X-link: https://github.com/facebookresearch/FBGEMM/pull/2045

Fix the direct_write for prefetch pipeline

Reviewed By: kausv, steven1327

Differential Revision: D85021220

fbshipit-source-id: 95348404db90e97e2884f9a366179541fef79e3f
1 parent 51210b8 commit 6a9e616

File tree

2 files changed

+472
-7
lines changed

2 files changed

+472
-7
lines changed

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 13 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1984,12 +1984,13 @@ def _prefetch( # noqa C901
19841984
# Store info for evicting the previous iteration's
19851985
# scratch pad after the corresponding backward pass is
19861986
# done
1987-
self.ssd_location_update_data.append(
1988-
(
1989-
sp_curr_prev_map_gpu,
1990-
inserted_rows,
1987+
if self.training:
1988+
self.ssd_location_update_data.append(
1989+
(
1990+
sp_curr_prev_map_gpu,
1991+
inserted_rows,
1992+
)
19911993
)
1992-
)
19931994

19941995
# Ensure the previous iterations eviction is complete
19951996
current_stream.wait_event(self.ssd_event_sp_evict)
@@ -2173,7 +2174,7 @@ def _prefetch( # noqa C901
21732174

21742175
# Store scratch pad info for post backward eviction only for training
21752176
# for eval job, no backward pass, so no need to store this info
2176-
if self.training and not self._embedding_cache_mode:
2177+
if self.training:
21772178
self.ssd_scratch_pad_eviction_data.append(
21782179
(
21792180
inserted_rows,
@@ -4548,6 +4549,12 @@ def direct_write_embedding(
45484549
if len(self.ssd_scratch_pad_eviction_data) > 0:
45494550
self.ssd_scratch_pad_eviction_data.pop(0)
45504551
if len(self.ssd_scratch_pad_eviction_data) > 0:
4552+
# Wait for any pending backend reads to the next scratch pad
4553+
# to complete before we write to it. Otherwise, stale backend data
4554+
# will overwrite our direct_write updates.
4555+
# The ssd_event_get marks completion of backend fetch operations.
4556+
current_stream.wait_event(self.ssd_event_get)
4557+
45514558
# if scratch pad exists, write to next batch scratch pad
45524559
sp = self.ssd_scratch_pad_eviction_data[0][0]
45534560
sp_idx = self.ssd_scratch_pad_eviction_data[0][1].to(

0 commit comments

Comments
 (0)