
[FlashRL 1/N] Add support for truncated importance sampling #145

Merged
SumanthRH merged 26 commits into NovaSky-AI:main from SumanthRH:t_i_s on Aug 15, 2025


Conversation


@SumanthRH (Member) commented Aug 12, 2025

What does this PR do?

Adds support for truncated importance sampling proposed in https://fengyao.notion.site/flash-rl

To support this, we add an optional rollout_logprobs field to the inference engine output. Because the default path has no use for rollout logprobs, they are not collected by default and the field is populated as None. There is some ugliness in handling both cases (with and without logprobs), but it's fine for now; a rough sketch of how the optional field threads through is included after the TODO list below.

TODO:

  • Run E2E GPU tests for vllm and sglang
  • Test agent loop path with tis (requires token in, token out fix - not taken up for now)
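As a rough illustration of the shape of this change (a sketch only; the type and field names here are simplified and are not the actual skyrl_train definitions):

```python
from typing import List, Optional, TypedDict


class EngineOutputSketch(TypedDict):
    # Token ids sampled by the inference engine, per trajectory.
    response_ids: List[List[int]]
    # Logprob of each sampled token, per trajectory; stays None on the
    # default path, where logprobs are never requested.
    rollout_logprobs: Optional[List[List[float]]]


def build_output(response_ids: List[List[int]],
                 rollout_logprobs: Optional[List[List[float]]] = None) -> EngineOutputSketch:
    # Downstream code must branch on rollout_logprobs being None vs. populated,
    # which is the "ugliness" mentioned above.
    return EngineOutputSketch(response_ids=response_ids, rollout_logprobs=rollout_logprobs)
```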

@gemini-code-assist bot left a comment

Summary of Changes

Hello @SumanthRH, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces foundational support for Truncated Importance Sampling (TIS) in the FlashRL framework. The primary goal is to enable more robust off-policy reinforcement learning by incorporating a mechanism to limit the impact of extreme importance ratios. This involves a comprehensive update to the generation and training pipelines to collect and utilize rollout log probabilities, which are crucial for TIS. New configuration options are provided to control TIS behavior, and a new example demonstrates its application. This change enhances the training stability and effectiveness of reinforcement learning algorithms within the system.

Highlights

  • Implementation of Truncated Importance Sampling (TIS): This PR introduces support for Truncated Importance Sampling (TIS) within the policy loss calculation. This involves modifying the ppo_policy_loss function to incorporate rollout_logprobs and apply a clamping mechanism to the importance ratio, as described in the FlashRL write-up (a minimal sketch follows this list).
  • Enhanced Log Probability Collection: The system is now capable of collecting rollout_logprobs from the inference engine. This required updates across various components, including GeneratorOutput, InferenceEngineOutput, skyrl_gym_generator.py, inference_engine_client.py, and vllm_engine.py, to ensure these log probabilities are correctly captured and passed through the generation pipeline.
  • Configuration Updates for TIS and Logprob Collection: New configuration parameters, use_tis and tis_imp_ratio_cap, have been added to ppo_base_config.yaml to enable and control the TIS behavior. Additionally, a get_logprobs parameter was introduced to sampling configurations, allowing explicit control over log probability collection.
  • Integration of Logprobs into Data Handling and Training: The data handling pipeline, from preprocessing (preprocess.py) to the replay buffer (replay_buffer.py) and the training input batch (training_batch.py), has been updated to seamlessly integrate the new rollout_logprobs tensor. This ensures that the collected log probabilities are available for use in the training worker's policy loss calculation.
  • New Example for TIS with DAPO: A new example (main_tis_dapo.py and run_dapo_tis.sh) has been added to demonstrate how to use TIS with DAPO (Decoupled Clip and Dynamic Sampling Policy Optimization) on the GSM8K dataset, providing a practical use case for the newly implemented features.
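A minimal sketch of the clamping step described in the first highlight, under the assumption that the loss keeps per-token terms and that a non-positive tis_imp_ratio_cap disables TIS (the config default is -1.0). Names such as apply_tis are illustrative, not the actual ppo_policy_loss signature:

```python
import torch


def apply_tis(per_token_loss: torch.Tensor,
              old_logprobs: torch.Tensor,       # logprobs recomputed by the trainer's policy
              rollout_logprobs: torch.Tensor,   # logprobs reported by the inference engine
              tis_imp_ratio_cap: float) -> torch.Tensor:
    # Truncated importance sampling (https://fengyao.notion.site/off-policy-rl):
    # weight each token's loss by pi_train / pi_rollout, clamped from above so
    # that extreme ratios cannot blow up the update.
    if tis_imp_ratio_cap <= 0:
        return per_token_loss  # TIS disabled (default config value is -1.0)
    imp_ratio = torch.exp(old_logprobs - rollout_logprobs)
    return per_token_loss * torch.clamp(imp_ratio, max=tis_imp_ratio_cap)
```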

@gemini-code-assist bot left a comment

Code Review

This pull request introduces support for truncated importance sampling (TIS) by adding an optional rollout_logprobs field throughout the data pipeline. The changes are extensive, affecting configuration, data processing, inference, and the training loop. The overall implementation appears solid, but I've identified a potential performance issue with vLLM sampling parameters, and several leftover debug print statements that should be removed.

top_p: 1.0
min_p: 0.0
top_k: -1
get_logprobs: false
Member Author

The main reason for a custom name here is that we have no use for logprobs beyond those of the chosen token, while vLLM has a much more flexible logprobs parameter. We'll eventually have differences in these sampling params anyway.

Member

Could you help me understand this a little more? IIUC, we could currently use the logprobs param that vllm provides and set it to 0 so that we only get the logprobs of the chosen token and achieve the same behavior as this PR. Are there other differences that would cause a problem/overhead with this?

I agree that the sampling params will diverge, but trying to understand why they need to diverge in this case.

Member Author

Ah, I just felt that the illusion of it being the same as vLLM was a bit much - i.e., we really only need logprobs for the chosen token, while vLLM lets you grab any number of them.

I could also keep it the same as vLLM and then add strict validation.

Member

Yeah, that makes sense to me! I do think vllm's use of logprobs is pretty compatible here (None to disable, 0 to get the chosen token, and k to get k additional tokens). And for now we can just validate that it's either None or 0?

Member Author

yes on it!
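For reference, the validation agreed on above could look roughly like this (a sketch only; the helper name and error message are made up):

```python
from typing import Optional


def validate_logprobs_param(logprobs: Optional[int]) -> Optional[int]:
    # Mirror vLLM's `logprobs` sampling param but accept only the two values
    # the trainer can currently consume: None (disabled) or 0 (logprob of the
    # sampled token only).
    if logprobs not in (None, 0):
        raise ValueError(
            f"logprobs must be None or 0, got {logprobs}: only the chosen "
            "token's logprob is used for training."
        )
    return logprobs
```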

@SumanthRH (Member, Author)

/gemini review

@gemini-code-assist bot left a comment

Code Review

This pull request adds support for Truncated Importance Sampling (TIS), which requires passing rollout log probabilities through the system. The changes span across configuration, data processing, generation, inference, and training components to accommodate the new rollout_logprobs field. While the core logic for TIS seems correct, I've identified several issues, including a critical bug where a check is implemented as a non-functional string, type inconsistencies in data structures, and a few bugs in handling the new log probabilities that could lead to incorrect behavior or runtime errors. I've provided detailed comments and suggestions to address these points.

"top_p": sampling_params.top_p,
"top_k": sampling_params.top_k,
"min_p": sampling_params.min_p,
"logprobs": 0 if sampling_params.get_logprobs else None, # 0 + 1 -> chosen token
gemini-code-assist bot (Contributor)

high

Setting logprobs to 0 in vLLM's SamplingParams disables log probability reporting. To get the log probability of the sampled token, this should be set to a positive integer (e.g., 1). The comment // 0 + 1 -> chosen token is confusing and seems incorrect.

Suggested change
"logprobs": 0 if sampling_params.get_logprobs else None, # 0 + 1 -> chosen token
"logprobs": 1 if sampling_params.get_logprobs else None,

Member Author

no
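The "no" is consistent with vLLM's documented behavior: logprobs=k returns the sampled token's logprob plus up to k additional top tokens, so logprobs=0 yields exactly the chosen token (hence the `0 + 1 -> chosen token` comment). A sketch of how those per-token logprobs can be read back from a vLLM RequestOutput, assuming the usual `outputs[0].logprobs` list-of-dicts structure:

```python
from typing import List


def chosen_token_logprobs(request_output) -> List[float]:
    # With SamplingParams(logprobs=0), completion.logprobs is a list with one
    # dict per generated position, each containing only the sampled token's
    # Logprob entry.
    completion = request_output.outputs[0]
    return [
        position_logprobs[token_id].logprob
        for token_id, position_logprobs in zip(completion.token_ids, completion.logprobs)
    ]
```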

from skyrl_train.utils import str_to_torch_dtype


@dataclass
Member Author

taken from vLLM

Member Author

probably better practice to import the type directly, but tis a tiny dataclass

class InferenceEngineOutput(TypedDict):
responses: List[str]
stop_reasons: List[str]
response_ids: Optional[List[List[int]]]
Member Author

Response ids are also taken from the inference engine output now because we need the response logprobs to match exactly with the response ids.

Otherwise, we would end up using the detokenized response string and re-tokenizing it, which can again lead to a mismatch since tokenization is not invertible.

Member

Yes, makes sense.
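A tiny illustration of the round-trip problem being avoided here (the tokenizer name is just an example; any HF tokenizer will do):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")  # example model

# Stand-in for token ids sampled by the inference engine.
sampled_ids = tokenizer.encode("hello  world", add_special_tokens=False)

# Decode then re-encode: not guaranteed to reproduce the original ids
# (whitespace handling, byte fallback, merges), so logprobs aligned to
# sampled_ids could become misaligned if we re-tokenized the string.
round_trip_ids = tokenizer.encode(tokenizer.decode(sampled_ids), add_special_tokens=False)
print(sampled_ids == round_trip_ids)
```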

else:
self.env_executor = None

if getattr(self.generator_cfg.sampling_params, "get_logprobs") and not self.generator_cfg.batched:
Member

qq: I assume this is because we are waiting on adding token-in-token-out support for the multi-turn case (ie, agent loop)?

Member Author

Yeah

@SumanthRH requested a review from tyler-griggs August 14, 2025 22:52
@SumanthRH (Member, Author)

With BF16, the difference in inference and trainer probabilities is small, so almost no change is expected.

I verified evaluation accuracy with and without TIS for the DAPO recipe we have on GSM8K. The curves match (note that the number of steps is a bit different between the two runs).

[Screenshots: evaluation accuracy curves with and without TIS]

@SumanthRH (Member, Author)

/gemini review

@gemini-code-assist bot left a comment

Code Review

This pull request adds support for Truncated Importance Sampling (TIS) by propagating rollout_logprobs through the system. The changes are extensive, touching configuration, data processing, inference engines, and the training loop. Overall, the implementation is solid, but there are a few areas that could be made more robust, particularly around handling optional data fields in distributed settings and ensuring consistency in applying the new TIS logic across different loss functions. I've provided specific suggestions to address these points.

@SumanthRH (Member, Author)

[Screenshot: new metrics logged when logprobs are enabled]

New metrics when logprobs are enabled. Note that sampling_params.logprobs can be enabled even without TIS. This could use better documentation, but I'll leave that for a later PR.

@tyler-griggs (Member) left a comment

Just one request to add new config params to the docs. Otherwise, LGTM

# dual clip parameters
clip_ratio_c: 3.0
# Truncated Importance Sampling as proposed in https://fengyao.notion.site/off-policy-rl
tis_imp_ratio_cap: -1.0
Member

Add new config params to the config docs?

@SumanthRH merged commit 81f2973 into NovaSky-AI:main Aug 15, 2025
1 check passed
SumanthRH added a commit that referenced this pull request Aug 15, 2025
…e() " (#150)

# What does this PR do?


Reverts #138 since it's incorrect after #145 - `rollout_logprobs` is also returned now and its dtype is `Optional[List[float]]`.

Also simplifies original approach a bit

---------

Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
SumanthRH added a commit that referenced this pull request Aug 19, 2025
…d training (#161)

# What does this PR do?

Supports a list of weights during weight sync for colocated training.
During colocated training, we use CUDA IPC for weight syncing. The
current impl syncs weights param by param, which can be pretty
inefficient. In this PR, we sync tensors in batches of a configurable
size (default 1GB). That is, we collect IPC metadata until the total
size of the underlying tensors reaches 1GB and forward it to the
inference engine. Each TP rank will materialize all tensors in this
list (i.e., an additional ~1GB of memory usage) and issue a single
load_weights call.

**How much faster is it?**

Even for a 14B model on an 8xH100 node (TP2), the weight sync time can
drop from around 4.4s to 1.6s (a ~60% reduction). This will matter much
more for larger models.

This PR is needed for the FlashRL integration to work well, because we
have a custom load weights impl that - long story short - allocates new
storage in each call and also issues some `empty_cache` calls. Without
batching, the load weights call will be too slow in such cases. This PR
reduces the weight sync time for a 1.5B model with FlashRL from 5 mins
to < 5s.

I've tested the PR with our E2E tests for colocated and non-colocated
and also tested the remote engine codepath.

This PR also makes the following changes:
- Fixes a bug introduced in #145 for the codepath with trajectory-based
routing when `response_ids` is not returned by the engine.
- Fixes a bug introduced in #126 for starting remote servers. Importing
`skyrl_train.utils.ppo_utils` will trigger registering. In some cases,
like with the vLLM server init, we will not call `sync_registries` and
there will be an error. The solution is to import-guard
`skyrl_train.utils.ppo_utils` unless the user imports it themselves (for
custom functions) or they go through the main entrypoint (main ->
`initialize_ray` -> sync).


TODO:
- [x] Verify non-colocated training works
- [x]  Run e2e test

---------

Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
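The batching scheme described above can be sketched as follows (a hypothetical helper; the real implementation ships CUDA IPC handles rather than the tensors themselves):

```python
from typing import Iterable, Iterator, List, Tuple

import torch


def batch_params_by_size(named_params: Iterable[Tuple[str, torch.Tensor]],
                         max_bytes: int = 1 << 30) -> Iterator[List[Tuple[str, torch.Tensor]]]:
    # Accumulate parameters until the running byte count would exceed max_bytes
    # (~1GB by default), then yield the batch so the inference engine can apply
    # it with a single load_weights call instead of one call per parameter.
    batch: List[Tuple[str, torch.Tensor]] = []
    batch_bytes = 0
    for name, tensor in named_params:
        size = tensor.numel() * tensor.element_size()
        if batch and batch_bytes + size > max_bytes:
            yield batch
            batch, batch_bytes = [], 0
        batch.append((name, tensor))
        batch_bytes += size
    if batch:
        yield batch
```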