[Tinker] Add save_weights_for_sampler() to WorkerDispatch #922
Conversation
Move weight sync logic into WorkerDispatch as a single entry point. This aligns with Tinker's API pattern where users explicitly call save_weights_for_sampler() after training and before sampling.

Changes:
- WorkerDispatch: Add save_weights_for_sampler() async method (sketched below) that:
  - Prepares GPU state (offloads optimizer, keeps model on GPU)
  - Wakes inference engine for weight transfer (colocate_all only)
  - Broadcasts weights to inference engines
  - Offloads model after sync (colocate_all only)
  - Wakes inference engine for KV cache (colocate_all only)
- Trainer: Replace two explicit weight sync blocks with:
  - dispatch.save_weights_for_sampler() after checkpoint load
  - dispatch.save_weights_for_sampler() after each training step
- Tests: Add test_save_weights_for_sampler.py with:
  - E2E test: train → sync → sample (colocate and non-colocate)
  - Multiple training steps before single sync

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
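A minimal sketch of the flow described above, assuming hypothetical helper names (offload_optimizer, offload_model, broadcast_weights, wake_up) and a colocate_all flag; the actual method body in the PR may differ:

```python
# Hypothetical sketch only: the helper names and the colocate_all flag are
# assumptions for illustration, not the PR's actual internals.
class WorkerDispatch:
    def __init__(self, actor_group, inference_engine_client, colocate_all: bool):
        self.actor_group = actor_group
        self.inference_engine_client = inference_engine_client
        self.colocate_all = colocate_all

    async def save_weights_for_sampler(self) -> None:
        """Single entry point: sync policy weights to inference engines."""
        # Prepare GPU state: offload the optimizer, keep the model on GPU.
        await self.actor_group.offload_optimizer()
        if self.colocate_all:
            # Wake the inference engine so it can receive the new weights.
            await self.inference_engine_client.wake_up(tags=["weights"])
        # Broadcast weights from the training workers to the inference engines.
        await self.actor_group.broadcast_weights(self.inference_engine_client)
        if self.colocate_all:
            # Free GPU memory now that the engines hold a copy of the weights.
            await self.actor_group.offload_model()
            # Wake again so the engine can reallocate its KV cache.
            await self.inference_engine_client.wake_up(tags=["kv_cache"])
```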
Code Review
This pull request significantly improves the codebase by encapsulating the weight synchronization logic into a single save_weights_for_sampler() method within WorkerDispatch. This refactoring greatly simplifies the Trainer class, making the training loop cleaner and more readable. Passing the inference_engine_client to the WorkerDispatch constructor is a good design choice for encapsulation. The addition of comprehensive end-to-end tests for the new functionality is also a major plus, ensuring the changes are robust. I have a couple of suggestions to improve consistency in the new test files.
was this intended to be added since there's already #898?
Address gemini-code-assist comments 2 & 3:
- Create WorkerDispatch before calling init_weight_sync_state
- Use dispatch.init_weight_sync_state(client) instead of calling directly on the actor group
- This properly tests the WorkerDispatch API as designed (a rough illustration follows)
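A rough illustration of the revised test setup; make_inference_client and make_actor_group are hypothetical fixtures, and only the ordering and the dispatch.init_weight_sync_state(client) call reflect the commit:

```python
# Placeholder fixtures below; whether init_weight_sync_state is awaitable is
# also an assumption for this sketch.
async def setup_for_test():
    client = make_inference_client()  # hypothetical fixture
    actor_group = make_actor_group()  # hypothetical fixture
    # Create the WorkerDispatch first, then initialize weight sync through it
    # instead of calling init_weight_sync_state directly on the actor group.
    dispatch = WorkerDispatch(actor_group, client, colocate_all=True)
    await dispatch.init_weight_sync_state(client)
    return dispatch
```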
Sorry I just got confused :D

/gemini review
Code Review
This pull request effectively refactors the weight synchronization logic by introducing save_weights_for_sampler() in WorkerDispatch. This change centralizes the weight syncing process, simplifying the Trainer class and improving code clarity and maintainability, which aligns well with the stated goal of adopting the Tinker API pattern. The addition of comprehensive tests in test_save_weights_for_sampler.py is a great inclusion, ensuring the new functionality is robust across different configurations. The changes are well-implemented and a clear improvement.
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Summary
Adds save_weights_for_sampler() to WorkerDispatch as the single entry point for syncing policy weights to inference engines before sampling. Calls to save_weights_for_sampler() now trigger the weight sync, rather than explicit weight synchronization logic in the trainer. This aligns with the Tinker API pattern where users explicitly call save_weights_for_sampler() after training and before sampling.

Changes:
- save_weights_for_sampler() (to reflect the Tinker library) that handles the full weight sync flow; a hedged usage sketch follows
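As a rough illustration of the resulting trainer loop, where everything other than save_weights_for_sampler() is a hypothetical placeholder for the trainer's real methods:

```python
# Illustrative only: load_checkpoint, generate_rollouts, and train_step are
# stand-ins for whatever the trainer actually does between syncs.
async def training_loop(trainer, dispatch, num_steps: int):
    trainer.load_checkpoint()
    # Sync once after checkpoint load, before the first sampling round.
    await dispatch.save_weights_for_sampler()
    for _ in range(num_steps):
        batch = await trainer.generate_rollouts()  # sample with current weights
        trainer.train_step(batch)
        # Explicit sync after each training step, per the Tinker API pattern.
        await dispatch.save_weights_for_sampler()
```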
Testing

Added test_save_weights_for_sampler tests, all passing.
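A hypothetical shape for the E2E test, assuming pytest with pytest-asyncio; all fixture and helper names are invented for illustration:

```python
import pytest

# All helpers here are illustrative stand-ins for the real test fixtures.
@pytest.mark.asyncio
@pytest.mark.parametrize("colocate_all", [True, False])
async def test_train_sync_sample(colocate_all):
    dispatch = make_dispatch(colocate_all=colocate_all)  # hypothetical factory
    # Multiple training steps before a single sync.
    for _ in range(2):
        run_training_step(dispatch)  # hypothetical helper
    await dispatch.save_weights_for_sampler()
    # Sampling after the sync should see the freshly trained weights.
    outputs = await sample_from_engines(dispatch)  # hypothetical helper
    assert outputs
```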