Commit f914e75

Fix async checkpoint timing in DCP recipe
Move checkpoint_future.result() before optimizer.step() to ensure the previous checkpoint completes before weights are modified in-place. This allows better overlap of checkpointing with forward/backward passes. Fixes #3584
1 parent 7f8b6dc commit f914e75

File tree

1 file changed: +5 -3 lines


recipes_source/distributed_async_checkpoint_recipe.rst

Lines changed: 5 additions & 3 deletions
@@ -257,12 +257,14 @@ checkpoint requests users can take advantage of direct memory access to speed up
     for step in range(10):
         optimizer.zero_grad()
         model(torch.rand(8, 16, device="cuda")).sum().backward()
-        optimizer.step()
 
-        state_dict = { "app": AppState(model, optimizer) }
+        # Wait for the previous checkpoint to finish before optimizer.step() modifies weights in-place
         if checkpoint_future is not None:
-            # waits for checkpointing to finish, avoiding queuing more then one checkpoint request at a time
             checkpoint_future.result()
+
+        optimizer.step()
+
+        state_dict = { "app": AppState(model, optimizer) }
         checkpoint_future = dcp.async_save(state_dict, storage_writer=writer, checkpoint_id=f"{CHECKPOINT_DIR}_step{step}")
 
     cleanup()
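
For context, here is a sketch of how the loop reads once this change is applied, written as a single-process stand-in for the recipe's multi-rank FSDP setup. The AppState wrapper follows the standard DCP Stateful pattern, and the CHECKPOINT_DIR, model, optimizer, and writer definitions below are illustrative assumptions rather than the recipe's exact code.

import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful

CHECKPOINT_DIR = "checkpoint"  # illustrative path; the recipe defines its own


class AppState(Stateful):
    """Wraps model and optimizer so DCP can fetch and restore their state dicts."""

    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        model_sd, optim_sd = get_state_dict(self.model, self.optimizer)
        return {"model": model_sd, "optim": optim_sd}

    def load_state_dict(self, state_dict):
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
        )


# Single-process stand-in for the recipe's setup (the recipe spawns ranks,
# calls init_process_group("nccl"), and wraps the model with FSDP).
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "12355")
dist.init_process_group("gloo", rank=0, world_size=1)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(16, 16).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

# cache_staged_state_dict=True (recent PyTorch versions) reuses a pinned staging
# buffer across async saves, which is the direct-memory-access speed-up the
# surrounding recipe text refers to.
writer = dcp.FileSystemWriter(CHECKPOINT_DIR, cache_staged_state_dict=True)

checkpoint_future = None
for step in range(10):
    optimizer.zero_grad()
    model(torch.rand(8, 16, device=device)).sum().backward()

    # Wait for the *previous* checkpoint before optimizer.step() mutates the
    # weights in-place; staging of that checkpoint still overlaps the
    # forward/backward pass above.
    if checkpoint_future is not None:
        checkpoint_future.result()

    optimizer.step()

    state_dict = {"app": AppState(model, optimizer)}
    checkpoint_future = dcp.async_save(
        state_dict, storage_writer=writer, checkpoint_id=f"{CHECKPOINT_DIR}_step{step}"
    )

# Drain the last outstanding checkpoint before tearing down.
if checkpoint_future is not None:
    checkpoint_future.result()
dist.destroy_process_group()

The final checkpoint_future.result() after the loop is a precaution not shown in the hunk above; it keeps the last async save from being abandoned when the process group is torn down.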
