[FlashRL 2/N] Support list of weights during weight sync for colocated training #161
Conversation
/gemini review
Code Review
This pull request introduces an important optimization for weight syncing in colocated training by batching weight updates, which should significantly improve efficiency. The changes to support lists of weights in `NamedWeightsUpdateRequest` and the corresponding updates across the codebase are well-implemented. The PR also includes two valuable bug fixes related to trajectory-based routing and remote server initialization.
I've identified a few issues that need attention:
- A critical typo in `vllm_engine.py` that will likely cause a `TypeError` during weight updates.
- A high-severity bug in `deepspeed_worker.py` where the batching logic for CUDA IPC transfers is not correctly implemented, as the batch size is not being tracked.
- A medium-severity return type mismatch in `sglang_engine.py`.
Addressing these points will ensure the new functionality is robust and works as expected. Overall, this is a great enhancement.
```python
if not success:
    raise RuntimeError(f"Update weight request failed with message: {message}")
return
```
The function `update_named_weights` is type-hinted to return `Tuple[bool, str]`, but this path returns `None` on success, which violates the function's contract. The IPC path correctly returns a tuple. To be consistent and correct, this path should return a tuple on success, for example `(True, "")`.
```diff
 if not success:
     raise RuntimeError(f"Update weight request failed with message: {message}")
-return
+return True, ""
```
What does this PR do?
Supports a list of weights during weight sync for colocated training. During colocated training, we use CUDA IPC for weight syncing. The current implementation syncs weights param by param, which can be pretty inefficient. In this PR, we sync tensors in batches whose size is a configurable parameter (default 1GB). That is, we collect IPC metadata until the total size of the underlying tensors reaches 1GB and forward it to the inference engine. Each TP rank will materialize all tensors in this list (i.e., additional memory usage of 1GB here) and issue a single `load_weights` call.
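As a rough illustration of the batching logic (a sketch, not the actual SkyRL-train code; `sync_weights_in_batches` and `send_batch_to_engine` are hypothetical names standing in for the trainer-side loop and the CUDA IPC forwarding path):

```python
from typing import Callable, Iterable, List, Tuple

import torch

GB = 1024**3


def sync_weights_in_batches(
    named_params: Iterable[Tuple[str, torch.Tensor]],
    send_batch_to_engine: Callable[[List[Tuple[str, torch.Tensor]]], None],
    batch_size_bytes: int = 1 * GB,  # configurable threshold from the description above
) -> None:
    """Accumulate named weights and flush them to the inference engine in ~1GB batches."""
    batch: List[Tuple[str, torch.Tensor]] = []
    batch_bytes = 0
    for name, tensor in named_params:
        batch.append((name, tensor))
        batch_bytes += tensor.numel() * tensor.element_size()
        # Flush once the accumulated tensors reach the threshold, so the receiving
        # TP rank only materializes ~1GB of extra memory per load_weights call.
        if batch_bytes >= batch_size_bytes:
            send_batch_to_engine(batch)
            batch, batch_bytes = [], 0
    if batch:
        send_batch_to_engine(batch)  # flush the remainder
```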
How much faster is it?
Even for a 14B model on an 8xH100 node (TP2), the weight sync time drops from around 4.4s to 1.6s (a ~60% reduction). This will matter much more for larger models.
This PR is needed for the FlashRL integration to work well, because we have a custom load-weights implementation that (long story short) allocates new storage in each call and also issues some `empty_cache` calls. Without batching, the load-weights call will be too slow in such cases. This PR reduces weight sync time for a 1.5B model with FlashRL from 5 mins to < 5s.
I've tested the PR with our E2E tests for colocated and non-colocated training and also tested the remote engine codepath.
This PR also makes the following changes:
- Fixes a bug introduced in #145 for the codepath with trajectory-based routing when `response_ids` is not returned by the engine.
- Fixes a bug introduced in #126 for starting remote servers. Importing `skyrl_train.utils.ppo_utils` triggers registering. In some cases, like with the vLLM server init, we will not call `sync_registries` and there will be an error. The solution is to import-guard `skyrl_train.utils.ppo_utils` unless the user imports it themselves (for custom functions) or goes through the main entrypoint (main -> `initialize_ray` -> sync); a minimal sketch of this pattern is shown after the TODO list.

TODO:
- [x] Verify non-colocated training works
- [x] Run e2e test
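To make the import-guard idea concrete, here is a minimal sketch assuming registration happens at import time of `skyrl_train.utils.ppo_utils`; `initialize_ray` and `sync_registries` are named above, but this body (including the assumption that `sync_registries` lives in `ppo_utils`) is illustrative, not the actual implementation:

```python
def initialize_ray(cfg) -> None:
    # Deferred import: `skyrl_train.utils.ppo_utils` registers functions as an
    # import side effect, so importing it here (instead of at module scope)
    # means registration only happens on the main entrypoint path, which also
    # syncs the registries. Standalone server init (e.g. the vLLM server)
    # never triggers the import.
    from skyrl_train.utils import ppo_utils  # noqa: F401

    # Assumed location of sync_registries, for illustration only.
    ppo_utils.sync_registries()
```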