
[FlashRL 2/N] Support list of weights during weight sync for colocated training #161

Merged

SumanthRH merged 15 commits into NovaSky-AI:main from SumanthRH:sumanthrh/impr-weight-sync on Aug 19, 2025

Conversation

@SumanthRH
Member

SumanthRH commented Aug 19, 2025

What does this PR do?

Supports a list of weights during weight sync for colocated training. During colocated training, we use CUDA IPC for weight syncing. The current implementation syncs weights param by param, which can be pretty inefficient. In this PR, we sync tensors in batches up to a configurable size threshold (default 1 GB). That is, we collect IPC metadata until the total size of the underlying tensors reaches 1 GB and then forward the batch to the inference engine. Each TP rank materializes all tensors in this list (i.e., roughly 1 GB of additional memory usage) and issues a single load_weights call.
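
To make the batching concrete, here is a minimal sketch of the size-thresholded accumulation described above (illustrative only: `send_batch_to_engine` and the `(name, tensor)` tuple layout are placeholders rather than the actual SkyRL interfaces, and the real implementation forwards CUDA IPC metadata instead of the tensors themselves):

```python
from typing import Callable, Iterable, List, Tuple

import torch

BATCH_SIZE_BYTES = 1 << 30  # configurable batch threshold, default 1 GB


def sync_weights_batched(
    named_params: Iterable[Tuple[str, torch.Tensor]],
    send_batch_to_engine: Callable[[List[Tuple[str, torch.Tensor]]], None],
) -> None:
    """Accumulate per-parameter entries and flush once the batch reaches the threshold."""
    batch: List[Tuple[str, torch.Tensor]] = []
    batch_bytes = 0
    for name, param in named_params:
        # Collect metadata instead of issuing one update per parameter.
        batch.append((name, param))
        batch_bytes += param.numel() * param.element_size()
        if batch_bytes >= BATCH_SIZE_BYTES:
            # On the engine side, each TP rank materializes every tensor in the batch
            # (~1 GB of extra memory) and issues a single load_weights call.
            send_batch_to_engine(batch)
            batch, batch_bytes = [], 0
    if batch:
        send_batch_to_engine(batch)
```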

How much faster is it?

Even for a 14B model on an 8xH100 node (TP2), the weight sync time drops from around 4.4s to 1.6s (a roughly 60% reduction). This will matter much more for larger models.

This PR is needed for the FlashRL integration to work well, because we have a custom load_weights implementation that, long story short, allocates new storage on each call and also issues some `empty_cache` calls. Without batching, the load_weights calls would be far too slow in that case. This PR reduces the weight sync time for a 1.5B model with FlashRL from 5 minutes to under 5 seconds.
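
For rough intuition (back-of-the-envelope, assuming bf16 weights): a 1.5B-parameter model is about 3 GB of weights, so a 1 GB batch threshold means only around 3-4 load_weights calls per sync instead of one call per parameter tensor (typically several hundred), which is why the fixed per-call overhead of the FlashRL load path stops dominating.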

I've tested the PR with our E2E tests for colocated and non-colocated training, and also tested the remote engine codepath.

This PR also makes the following changes:

  • Fixes a bug introduced in #145 in the codepath for trajectory-based routing when `response_ids` is not returned by the engine.
  • Fixes a bug introduced in #126 for starting remote servers: importing `skyrl_train.utils.ppo_utils` triggers registration, and in some cases, like the vLLM server init, `sync_registries` is never called, which leads to an error. The solution is to guard the import of `skyrl_train.utils.ppo_utils` so it only happens when the user imports it themselves (for custom functions) or goes through the main entrypoint (main -> `initialize_ray` -> sync); a sketch of the pattern follows this list.
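
To make the second fix concrete, here is a self-contained toy sketch of the guard pattern (the function and registry names below are illustrative stand-ins, not SkyRL's actual code; in the real fix the side effect in question is the import of `skyrl_train.utils.ppo_utils` itself):

```python
# Toy illustration of "register only on paths that also sync".
registry = {}


def register_defaults():
    # Stand-in for the import-time side effect of skyrl_train.utils.ppo_utils,
    # which registers functions when the module is imported.
    registry["example_fn"] = lambda x: x


def sync_registries():
    # Stand-in for the sync step that downstream workers rely on.
    print(f"syncing {len(registry)} registered entries")


def initialize_ray():
    # Main entrypoint path (main -> initialize_ray -> sync): registration and sync
    # happen together, so registries are never left half-initialized.
    register_defaults()
    sync_registries()


def start_remote_server():
    # A codepath like the standalone vLLM server init never calls sync_registries,
    # so it must not trigger registration either; previously a module-level import
    # did so unconditionally and caused an error.
    pass
```
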
TODO:

  • [x] Verify non-colocated training works
  • [x] Run e2e test

SumanthRH marked this pull request as ready for review August 19, 2025 09:24
@SumanthRH
Member Author

/gemini review

@gemini-code-assist
Contributor

gemini-code-assist bot left a comment

Code Review

This pull request introduces an important optimization for weight syncing in colocated training by batching weight updates, which should significantly improve efficiency. The changes to support lists of weights in NamedWeightsUpdateRequest and the corresponding updates across the codebase are well-implemented. The PR also includes two valuable bug fixes related to trajectory-based routing and remote server initialization.

I've identified a few issues that need attention:

  • A critical typo in vllm_engine.py that will likely cause a TypeError during weight updates.
  • A high-severity bug in deepspeed_worker.py where the batching logic for CUDA IPC transfers is not correctly implemented, as the batch size is not being tracked.
  • A medium-severity return type mismatch in sglang_engine.py.

Addressing these points will ensure the new functionality is robust and works as expected. Overall, this is a great enhancement.

Comment on lines +313 to +315

    if not success:
        raise RuntimeError(f"Update weight request failed with message: {message}")
    return
Contributor

medium

The function `update_named_weights` is type-hinted to return `Tuple[bool, str]`, but this path returns `None` on success, which violates the function's contract. The IPC path correctly returns a tuple. To be consistent and correct, this should return a tuple on success, for example `(True, "")`.

Suggested change

      if not success:
          raise RuntimeError(f"Update weight request failed with message: {message}")
-     return
+     return True, ""

@tyler-griggs
Member

tyler-griggs left a comment

Nice!

SumanthRH merged commit d07f293 into NovaSky-AI:main Aug 19, 2025
3 checks passed
dzorlu referenced this pull request in fleet-ai/SkyRL Feb 4, 2026