[FlashRL 1/N] Add support for truncated importance sampling#145
SumanthRH merged 26 commits into NovaSky-AI:main
Conversation
Summary of Changes
Hello @SumanthRH, I'm Gemini Code Assist[^1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces foundational support for Truncated Importance Sampling (TIS) in the FlashRL framework. The primary goal is to enable more robust off-policy reinforcement learning by incorporating a mechanism to limit the impact of extreme importance ratios. This involves a comprehensive update to the generation and training pipelines to collect and utilize rollout log probabilities, which are crucial for TIS. New configuration options are provided to control TIS behavior, and a new example demonstrates its application. This change enhances the training stability and effectiveness of reinforcement learning algorithms within the system.
Highlights
- Implementation of Truncated Importance Sampling (TIS): This PR introduces support for Truncated Importance Sampling (TIS) within the policy loss calculation. This involves modifying the `ppo_policy_loss` function to incorporate `rollout_logprobs` and apply a clamping mechanism to the importance ratio, as described in the FlashRL paper (a rough sketch of this correction appears after this list).
- Enhanced Log Probability Collection: The system is now capable of collecting `rollout_logprobs` from the inference engine. This required updates across various components, including `GeneratorOutput`, `InferenceEngineOutput`, `skyrl_gym_generator.py`, `inference_engine_client.py`, and `vllm_engine.py`, to ensure these log probabilities are correctly captured and passed through the generation pipeline.
- Configuration Updates for TIS and Logprob Collection: New configuration parameters, `use_tis` and `tis_imp_ratio_cap`, have been added to `ppo_base_config.yaml` to enable and control the TIS behavior. Additionally, a `get_logprobs` parameter was introduced to sampling configurations, allowing explicit control over log probability collection.
- Integration of Logprobs into Data Handling and Training: The data handling pipeline, from preprocessing (`preprocess.py`) to the replay buffer (`replay_buffer.py`) and the training input batch (`training_batch.py`), has been updated to integrate the new `rollout_logprobs` tensor. This ensures that the collected log probabilities are available for use in the training worker's policy loss calculation.
- New Example for TIS with DAPO: A new example (`main_tis_dapo.py` and `run_dapo_tis.sh`) has been added to demonstrate how to use TIS with DAPO on the GSM8K dataset, providing a practical use case for the newly implemented features.
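For readers unfamiliar with the mechanism, the sketch below shows how a TIS correction of this kind can be folded into a PPO-style policy loss. This is a minimal, hedged sketch that reuses the names from this PR (`tis_imp_ratio_cap`, `rollout_logprobs`); the signature and masking convention are illustrative, not the PR's exact implementation.

```python
import torch

def ppo_policy_loss_with_tis(
    logprobs: torch.Tensor,          # per-token logprobs under the current policy
    old_logprobs: torch.Tensor,      # per-token logprobs recorded by the trainer at rollout time
    rollout_logprobs: torch.Tensor,  # per-token logprobs reported by the inference engine
    advantages: torch.Tensor,
    mask: torch.Tensor,              # 1 for response tokens, 0 for padding
    clip_eps: float = 0.2,
    tis_imp_ratio_cap: float = 2.0,  # <= 0 disables TIS
) -> torch.Tensor:
    # Standard PPO clipped surrogate objective.
    ratio = torch.exp(logprobs - old_logprobs)
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    loss = -torch.min(surr1, surr2)

    # Truncated importance sampling: reweight each token's loss by the ratio
    # between the trainer's logprobs and the inference engine's rollout
    # logprobs, capped at tis_imp_ratio_cap to limit extreme ratios.
    if tis_imp_ratio_cap > 0:
        tis_ratio = torch.exp(old_logprobs - rollout_logprobs)
        tis_ratio = torch.clamp(tis_ratio, max=tis_imp_ratio_cap)
        loss = loss * tis_ratio

    # Mean over valid (non-padding) tokens.
    return (loss * mask).sum() / mask.sum().clamp(min=1)
```

The key point is that each token's loss is multiplied by exp(logprob_trainer - logprob_rollout), clamped above by `tis_imp_ratio_cap`, which compensates for the mismatch between the inference engine's sampling distribution and the trainer's policy without letting extreme ratios dominate.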
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |
Customization
To customize Gemini Code Assist for the GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.
You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes

[^1]: Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.
Code Review
This pull request introduces support for truncated importance sampling (TIS) by adding an optional rollout_logprobs field throughout the data pipeline. The changes are extensive, affecting configuration, data processing, inference, and the training loop. The overall implementation appears solid, but I've identified a potential performance issue with vLLM sampling parameters, and several leftover debug print statements that should be removed.
    top_p: 1.0
    min_p: 0.0
    top_k: -1
    get_logprobs: false
The main reason for a custom name here is that there is no use for logprobs beyond those of the chosen token, while vLLM has a much more flexible `logprobs` parameter. We'll eventually have differences in these sampling params anyway.
Could you help me understand this a little more? IIUC, we could currently use the logprobs param that vllm provides and set it to 0 so that we only get the logprobs of the chosen token and achieve the same behavior as this PR. Are there other differences that would cause a problem/overhead with this?
I agree that the sampling params will diverge, but trying to understand why they need to diverge in this case.
Ah, I just felt that the illusion of having it be the same as vLLM was a bit much - i.e., we really just need logprobs for the chosen token, but vLLM allows you to grab any number.
I could also keep it the same as vLLM and then have strict validation.
Yeah, that makes sense to me! I do think vllm's use of logprobs is pretty compatible here (None to disable, 0 to get the chosen token, and k to get k additional tokens). And for now we can just validate that it's either None or 0?
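For context on the semantics being discussed, here is a minimal sketch of requesting chosen-token logprobs from vLLM's offline API. The model name is an arbitrary placeholder and result-field details can vary slightly across vLLM versions, so treat this as illustrative rather than exact.

```python
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")  # placeholder model

# logprobs=None disables logprob reporting; logprobs=0 asks for the sampled
# token's logprob only; logprobs=k additionally returns the top-k tokens.
params = SamplingParams(temperature=1.0, max_tokens=16, logprobs=0)

outputs = llm.generate(["What is 2 + 2?"], params)
completion = outputs[0].outputs[0]

# completion.logprobs is a per-token dict of {token_id: Logprob}; with
# logprobs=0 it holds only the sampled token, which is all TIS needs.
rollout_logprobs = [
    token_logprobs[token_id].logprob
    for token_id, token_logprobs in zip(completion.token_ids, completion.logprobs)
]
print(completion.text, rollout_logprobs)
```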
/gemini review
Code Review
This pull request adds support for Truncated Importance Sampling (TIS), which requires passing rollout log probabilities through the system. The changes span across configuration, data processing, generation, inference, and training components to accommodate the new rollout_logprobs field. While the core logic for TIS seems correct, I've identified several issues, including a critical bug where a check is implemented as a non-functional string, type inconsistencies in data structures, and a few bugs in handling the new log probabilities that could lead to incorrect behavior or runtime errors. I've provided detailed comments and suggestions to address these points.
skyrl-train/skyrl_train/inference_engines/inference_engine_client.py (outdated review thread, resolved)
| "top_p": sampling_params.top_p, | ||
| "top_k": sampling_params.top_k, | ||
| "min_p": sampling_params.min_p, | ||
| "logprobs": 0 if sampling_params.get_logprobs else None, # 0 + 1 -> chosen token |
Setting logprobs to 0 in vLLM's SamplingParams disables log probability reporting. To get the log probability of the sampled token, this should be set to a positive integer (e.g., 1). The comment `# 0 + 1 -> chosen token` is confusing and seems incorrect.

Suggested change:

    - "logprobs": 0 if sampling_params.get_logprobs else None,  # 0 + 1 -> chosen token
    + "logprobs": 1 if sampling_params.get_logprobs else None,
    from skyrl_train.utils import str_to_torch_dtype


    @dataclass
Probably better practice to import the type directly, but it's a tiny dataclass.
    class InferenceEngineOutput(TypedDict):
        responses: List[str]
        stop_reasons: List[str]
        response_ids: Optional[List[List[int]]]
Response IDs are also taken from the inference engine output now because we need the response logprobs to match exactly with the response IDs. Otherwise, we would end up using the detokenized response string and then re-tokenizing it, which can lead to a mismatch since tokenization is not invertible.
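As a quick illustration of why the engine's own `response_ids` are carried through (a sketch; the tokenizer name is a placeholder, not necessarily the one used in this PR):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")  # placeholder tokenizer

engine_response_ids = tok("Hello  world!")["input_ids"]  # token IDs as the engine produced them
text = tok.decode(engine_response_ids)                   # detokenize to a string
retokenized_ids = tok(text)["input_ids"]                 # re-tokenize that string

# The two ID sequences are not guaranteed to be identical (whitespace
# normalization, merge ambiguities, special tokens), so per-token rollout
# logprobs must stay aligned with the engine's response_ids rather than be
# re-derived from the decoded text.
print(engine_response_ids == retokenized_ids)
```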
    else:
        self.env_executor = None

    if getattr(self.generator_cfg.sampling_params, "get_logprobs") and not self.generator_cfg.batched:
qq: I assume this is because we are waiting on adding token-in, token-out support for the multi-turn case (i.e., the agent loop)?
/gemini review
Code Review
This pull request adds support for Truncated Importance Sampling (TIS) by propagating rollout_logprobs through the system. The changes are extensive, touching configuration, data processing, inference engines, and the training loop. Overall, the implementation is solid, but there are a few areas that could be made more robust, particularly around handling optional data fields in distributed settings and ensuring consistency in applying the new TIS logic across different loss functions. I've provided specific suggestions to address these points.
skyrl-train/skyrl_train/inference_engines/inference_engine_client.py (outdated review thread, resolved)
tyler-griggs left a comment:
Just one request to add the new config params to the docs. Otherwise, LGTM.
    # dual clip parameters
    clip_ratio_c: 3.0
    # Truncated Importance Sampling as proposed in https://fengyao.notion.site/off-policy-rl
    tis_imp_ratio_cap: -1.0
Add new config params to the config docs?
…d training (#161)

# What does this PR do?
Supports a list of weights during weight sync for colocated training. During colocated training, we use CUDA IPC for weight syncing. The current impl syncs weights param by param, which can be pretty inefficient. In this PR, we sync tensors in batches of a configurable size (default 1GB). That is, we collect IPC metadata until the total size of the underlying tensors reaches 1GB and then forward it to the inference engine. Each TP rank will materialize all tensors in this list (i.e., an additional ~1GB of memory usage) and issue a single `load_weights` call.

**How much faster is it?** Even for a 14B model on an 8xH100 node (TP2), the weight sync time drops from around 4.4s to 1.6s (a 60% reduction). This will matter much more for larger models. This PR is needed for the FlashRL integration to work well, because we have a custom load-weights impl that - long story short - allocates new storage in each call and also issues some `empty_cache` calls. Without batching, the load-weights call would be too slow in such cases. This PR reduces the weight sync time for a 1.5B model with FlashRL from 5 minutes to under 5 seconds.

I've tested the PR with our E2E tests for colocated and non-colocated training and also tested the remote engine codepath.

This PR also makes the following changes:
- Fixes a bug introduced in #145 in the codepath with trajectory-based routing when `response_ids` is not returned by the engine.
- Fixes a bug introduced in #126 for starting remote servers. Importing `skyrl_train.utils.ppo_utils` will trigger registering. In some cases, like with the vLLM server init, we will not call `sync_registries` and there will be an error. The solution is to import-guard `skyrl_train.utils.ppo_utils` unless the user imports it themselves (for custom functions) or goes through the main entrypoint (main -> `initialize_ray` -> sync).

TODO:
- [x] Verify non-colocated training works
- [x] Run e2e test

Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
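A rough sketch of the batching idea described in that PR, under stated assumptions: the helper names (`named_params`, `send_batch`) are hypothetical, and the real implementation forwards CUDA IPC metadata rather than the tensors themselves.

```python
DEFAULT_BATCH_BYTES = 1 << 30  # 1GB default batch size

def sync_weights_in_batches(named_params, send_batch, batch_bytes=DEFAULT_BATCH_BYTES):
    """Group parameters until their combined size reaches `batch_bytes`,
    then forward the whole group to the inference engine in one call,
    instead of issuing one load_weights round-trip per parameter."""
    batch, batch_size = [], 0
    for name, param in named_params:
        batch.append((name, param))
        batch_size += param.numel() * param.element_size()
        if batch_size >= batch_bytes:
            send_batch(batch)  # hypothetical: one load_weights call per batch
            batch, batch_size = [], 0
    if batch:
        send_batch(batch)      # flush the final partial batch

# Illustrative usage:
# sync_weights_in_batches(model.named_parameters(),
#                         send_batch=lambda b: engine.load_weights(b))
```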
What does this PR do?
Adds support for truncated importance sampling proposed in https://fengyao.notion.site/flash-rl
To support this, we need to add an optional `rollout_logprobs` field at the inference engine. Note that because the default path has no use for rollout logprobs, it is not obtained by default and is populated as `None`. There is some ugliness in handling both these cases (no logprobs and with logprobs), but it's fine for now.

TODO:
- [x] Run E2E GPU tests for vllm and sglang
- [x] Test agent loop path with TIS (requires a token-in, token-out fix - not taken up for now)