Conversation

@tyler-griggs (Member) commented on Jan 22, 2026

Summary

Adds save_weights_for_sampler() to WorkerDispatch as the single entry point for syncing policy weights to the inference engines before sampling. Calls to save_weights_for_sampler() are now the trigger for weight sync, rather than explicit weight-synchronization logic living in the trainer. This aligns with the Tinker API pattern, where users explicitly call save_weights_for_sampler() after training and before sampling.

Changes:

  • WorkerDispatch: Added save_weights_for_sampler() (mirroring the Tinker library; see the sketch after this list) that handles the full weight sync flow:
    • Prepares GPU state (offloads optimizer, keeps model on GPU)
    • Wakes inference engine for weight transfer (colocate_all only)
    • Broadcasts weights to inference engines
    • Offloads model after sync (colocate_all only)
    • Wakes inference engine for KV cache (colocate_all only)
  • Trainer: Replaced two explicit weight sync blocks with dispatch.save_weights_for_sampler():
    • After checkpoint load (before first generation)
    • After each training step (before next generation)
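
The flow above can be pictured with a minimal sketch. This is not the actual implementation: the helper calls on the policy actor group and the inference engine client (offload_optimizer, wake_up, broadcast_weights_to_inference_engines, offload_model) are illustrative assumptions; only WorkerDispatch, save_weights_for_sampler(), inference_engine_client, and the colocate_all flag come from this PR.

```python
# Minimal sketch of the weight-sync flow. Helper names on policy_group and
# inference_engine_client are placeholders, not the real internals.
class WorkerDispatch:
    def __init__(self, policy_group, inference_engine_client, colocate_all: bool):
        self.policy_group = policy_group
        self.inference_engine_client = inference_engine_client
        self.colocate_all = colocate_all

    async def save_weights_for_sampler(self) -> None:
        # Prepare GPU state: offload the optimizer, keep the model on GPU.
        await self.policy_group.offload_optimizer()
        if self.colocate_all:
            # Wake the inference engine just enough to receive weights.
            await self.inference_engine_client.wake_up(tags=["weights"])
        # Broadcast the updated policy weights to the inference engines.
        await self.policy_group.broadcast_weights_to_inference_engines(
            self.inference_engine_client
        )
        if self.colocate_all:
            # Offload the training model now that the sync is done.
            await self.policy_group.offload_model()
            # Wake the inference engine fully so the KV cache is allocated.
            await self.inference_engine_client.wake_up(tags=["kv_cache"])
```

In the trainer, the two former sync blocks reduce to a single await dispatch.save_weights_for_sampler() call: once after checkpoint load and once after each training step, in both cases before the next generation.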

Testing

Added test_save_weights_for_sampler tests, all passing.

Move weight sync logic into WorkerDispatch as a single entry point.
This aligns with Tinker's API pattern where users explicitly call
save_weights_for_sampler() after training and before sampling.

Changes:
- WorkerDispatch: Add save_weights_for_sampler() async method that:
  - Prepares GPU state (offloads optimizer, keeps model on GPU)
  - Wakes inference engine for weight transfer (colocate_all only)
  - Broadcasts weights to inference engines
  - Offloads model after sync (colocate_all only)
  - Wakes inference engine for KV cache (colocate_all only)
- Trainer: Replace two explicit weight sync blocks with:
  - dispatch.save_weights_for_sampler() after checkpoint load
  - dispatch.save_weights_for_sampler() after each training step
- Tests: Add test_save_weights_for_sampler.py (sketched below) with:
  - E2E test: train → sync → sample (colocate and non-colocate)
  - Multiple training steps before a single sync
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
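
A plausible shape for those tests is sketched below. The make_test_stack helper and the trainer/sampler handles are placeholders, not the real fixtures; only the train → save_weights_for_sampler() → sample ordering and the colocate/non-colocate parametrization come from the commit message.

```python
import pytest

# Hypothetical reconstruction of the E2E test shape; make_test_stack and the
# trainer/sampler objects are placeholders for the real test fixtures.
@pytest.mark.asyncio
@pytest.mark.parametrize("colocate_all", [True, False])
async def test_save_weights_for_sampler(colocate_all):
    dispatch, trainer, sampler = make_test_stack(colocate_all=colocate_all)

    # Train one step, sync weights, then sample with the updated policy.
    await trainer.train_step()
    await dispatch.save_weights_for_sampler()
    assert await sampler.generate(["hello"])

    # Several training steps may precede a single sync; only the weights at
    # sync time need to reach the inference engines before sampling.
    for _ in range(3):
        await trainer.train_step()
    await dispatch.save_weights_for_sampler()
    assert await sampler.generate(["hello again"])
```
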
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request significantly improves the codebase by encapsulating the weight synchronization logic into a single save_weights_for_sampler() method within WorkerDispatch. This refactoring greatly simplifies the Trainer class, making the training loop cleaner and more readable. Passing the inference_engine_client to the WorkerDispatch constructor is a good design choice for encapsulation. The addition of comprehensive end-to-end tests for the new functionality is also a major plus, ensuring the changes are robust. I have a couple of suggestions to improve consistency in the new test files.

@erictang000 (Collaborator)

Was this intended to be added, since there's already #898?

Address gemini-code-assist comments 2 & 3 (setup sketched after this message):
- Create WorkerDispatch before calling init_weight_sync_state
- Use dispatch.init_weight_sync_state(client) instead of calling it
  directly on the actor group
- This properly exercises the WorkerDispatch API as designed
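
Roughly, that reordering in the test setup looks like the snippet below; the constructor arguments are assumptions, while init_weight_sync_state(client) and the dispatch-first ordering come from the commit message.

```python
# Build the dispatcher first, then route weight-sync initialization through
# its API instead of calling it directly on the actor group.
dispatch = WorkerDispatch(policy_group, inference_engine_client=client)  # assumed constructor args
await dispatch.init_weight_sync_state(client)  # rather than policy_group.init_weight_sync_state(client)
```
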
@tyler-griggs (Member, Author)

Sorry I just got confused :D

@tyler-griggs (Member, Author)

/gemini review

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request effectively refactors the weight synchronization logic by introducing save_weights_for_sampler() in WorkerDispatch. This change centralizes the weight syncing process, simplifying the Trainer class and improving code clarity and maintainability, which aligns well with the stated goal of adopting the Tinker API pattern. The addition of comprehensive tests in test_save_weights_for_sampler.py is a great inclusion, ensuring the new functionality is robust across different configurations. The changes are well-implemented and a clear improvement.

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@tyler-griggs merged commit 48faf59 into main on Jan 23, 2026