@@ -1984,12 +1984,13 @@ def _prefetch( # noqa C901
19841984 # Store info for evicting the previous iteration's
19851985 # scratch pad after the corresponding backward pass is
19861986 # done
1987- self .ssd_location_update_data .append (
1988- (
1989- sp_curr_prev_map_gpu ,
1990- inserted_rows ,
1987+ if self .training :
1988+ self .ssd_location_update_data .append (
1989+ (
1990+ sp_curr_prev_map_gpu ,
1991+ inserted_rows ,
1992+ )
19911993 )
1992- )
19931994
19941995 # Ensure the previous iteration's eviction is complete
19951996 current_stream .wait_event (self .ssd_event_sp_evict )
@@ -2173,7 +2174,7 @@ def _prefetch( # noqa C901
21732174
21742175 # Store scratch pad info for post-backward eviction only during training.
21752176 # An eval job has no backward pass, so there is no need to store this info.
2176- if self .training and not self . _embedding_cache_mode :
2177+ if self .training :
21772178 self .ssd_scratch_pad_eviction_data .append (
21782179 (
21792180 inserted_rows ,
@@ -4548,6 +4549,12 @@ def direct_write_embedding(
45484549 if len (self .ssd_scratch_pad_eviction_data ) > 0 :
45494550 self .ssd_scratch_pad_eviction_data .pop (0 )
45504551 if len (self .ssd_scratch_pad_eviction_data ) > 0 :
4552+ # Wait for any pending backend reads to the next scratch pad
4553+ # to complete before we write to it. Otherwise, stale backend data
4554+ # will overwrite our direct_write updates.
4555+ # The ssd_event_get marks completion of backend fetch operations.
4556+ current_stream .wait_event (self .ssd_event_get )
4557+
45514558 # if scratch pad exists, write to next batch scratch pad
45524559 sp = self .ssd_scratch_pad_eviction_data [0 ][0 ]
45534560 sp_idx = self .ssd_scratch_pad_eviction_data [0 ][1 ].to (
0 commit comments