64 changes: 64 additions & 0 deletions skyrl-train/docs/examples/flash_rl.rst
@@ -0,0 +1,64 @@
FlashRL + SkyRL: Training with FP8 Rollouts
===========================================

In this example, we walk through how to train a model with FP8 rollouts using `FlashRL <https://fengyao.notion.site/flash-rl>`_ and SkyRL.

We provide examples for training Qwen2.5-1.5B-Instruct and Qwen3-32B with FP8 rollouts.

What is FlashRL?
----------------

FlashRL is a novel method that provides the first RL recipe with quantized rollout generation while preserving downstream performance. It consists of two main components:

- Truncated Importance Sampling (TIS): In scalable RL frameworks, the policy model and rollout generation are typically managed by different libraries (e.g., FSDP and vLLM, respectively), which leads to a mismatch between their probability distributions. TIS addresses this rollout-training mismatch by applying a token-level correction factor, based on the truncated importance-sampling ratio, to the policy loss.
- Online Quantization Support: While vLLM supports inference with quantized weights, using this directly for RL training is tricky. FlashRL also provides patches for vLLM that support FP8 and INT8 weight updates during training.


FlashRL + SkyRL
---------------

SkyRL now supports an initial integration with FlashRL. Currently, we only support training with `online FP8 quantization <https://docs.vllm.ai/en/v0.9.2/features/quantization/fp8.html#online-dynamic-quantization>`_ in vLLM. To enable it, simply set ``FLASHRL_CONFIG=fp8_vllm`` in your environment variables and use the ``--extra flashrl`` flag when running the training script.


.. warning::

The FlashRL integration only supports single-turn training at the moment.


How does it work?
~~~~~~~~~~~~~~~~~~

We pass the ``quantization=fp8`` flag to the vLLM engine at initialization time. This means that the weights are loaded as usual in half precision and then quantized down to FP8. During training, generations are sampled as usual, but now from the quantized weights. Since vLLM uses online quantization, the scale factors used for quantizing activations are computed dynamically at runtime.
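
To see the online quantization path in isolation, here is a minimal sketch of the same setting in standalone vLLM, outside of SkyRL (the model name and prompt are just placeholders):

.. code-block:: python
   :caption: Minimal sketch of vLLM's online dynamic FP8 quantization (illustrative only)

   from vllm import LLM, SamplingParams

   # Weights are loaded in half precision and quantized down to FP8 on the fly;
   # activation scale factors are computed dynamically by vLLM at runtime.
   llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", quantization="fp8")

   outputs = llm.generate(["What is 2 + 2?"], SamplingParams(max_tokens=32))
   print(outputs[0].outputs[0].text)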

The sampled rollouts are then used to compute the policy loss. We further apply the TIS correction factor to the policy loss and then update the policy model weights. These weights, in half precision, are then synced with the inference engine layer by layer, where they are loaded and quantized down to FP8, just as at initialization.
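
To make the TIS correction concrete, below is a minimal, hedged sketch of a truncated importance-sampling weight applied to a token-level policy-gradient loss. The tensor names and the exact loss form are illustrative and not the SkyRL implementation; the cap corresponds to ``trainer.algorithm.tis_imp_ratio_cap``.

.. code-block:: python
   :caption: Sketch of truncated importance sampling (TIS) on the policy loss (illustrative only)

   import torch

   def tis_policy_loss(policy_logprobs, rollout_logprobs, advantages, loss_mask, cap=2.0):
       # Per-token importance-sampling ratio between the training policy
       # (e.g., FSDP in bf16) and the rollout policy (e.g., vLLM with FP8 weights).
       ratio = torch.exp(policy_logprobs - rollout_logprobs)
       # Truncate the ratio and detach it so it acts purely as a
       # correction weight on the loss, not as part of the gradient path.
       tis_weight = torch.clamp(ratio, max=cap).detach()
       # Apply the correction factor to a simple token-level policy-gradient loss.
       per_token_loss = -tis_weight * advantages * policy_logprobs
       return (per_token_loss * loss_mask).sum() / loss_mask.sum().clamp(min=1)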



Example
--------

We provide two examples of training with FP8 rollouts for DAPO: one for Qwen2.5-1.5B-Instruct and one for Qwen3-32B. The FlashRL-related files are in the ``skyrl_train/examples/flash_rl/`` folder.


.. code-block:: bash
:caption: Training configuration at ``skyrl_train/examples/flash_rl/run_dapo_flashrl.sh``

# path for dataset (.parquet files) containing the prompts and metadata for each question
DATA_DIR="$HOME/data/gsm8k"

uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.flashrl -m examples.flash_rl.main_dapo_flashrl \
...
trainer.algorithm.use_tis=true \
trainer.algorithm.tis_imp_ratio_cap=2.0 \
...

Here, we've configured training to use TIS with an importance-sampling ratio cap of 2.0. To make sure the FlashRL patches are applied to vLLM, we set the ``FLASHRL_CONFIG`` environment variable in ``examples/flash_rl/.env.flashrl``:

.. code-block:: bash
:caption: Environment variables at ``examples/flash_rl/.env.flashrl``

FLASHRL_CONFIG=fp8_vllm
...

.. warning::

The FlashRL integration is experimental. While generation times can improve for large models with quantization, we've observed that the time spent in weight syncing is much higher with FlashRL's FP8 path, which negates some of the benefits of FP8 inference. The slowdown is primarily due to slow weight quantization in vLLM's ``process_weights_after_loading`` function, and we are actively working on improving this.
1 change: 1 addition & 0 deletions skyrl-train/docs/index.rst
@@ -37,6 +37,7 @@ SkyRL is a full-stack RL library designed for modularity and extensibility.
examples/training_backends
examples/multi_turn_text2sql
examples/search
examples/flash_rl

.. toctree::
:maxdepth: 2
2 changes: 2 additions & 0 deletions skyrl-train/examples/flash_rl/.env.flashrl
@@ -0,0 +1,2 @@
FLASHRL_LOGGING_LEVEL=DEBUG # optional
FLASHRL_CONFIG="fp8_vllm"
126 changes: 126 additions & 0 deletions skyrl-train/examples/flash_rl/flash_rl_engine.py
@@ -0,0 +1,126 @@
from typing import List, Any, Dict, Optional
import ray
import vllm
from skyrl_train.inference_engines.vllm.vllm_engine import VLLMInferenceEngine
from skyrl_train.inference_engines.ray_wrapped_inference_engine import RayWrappedInferenceEngine
from ray.util.placement_group import PlacementGroupSchedulingStrategy, placement_group

from skyrl_train.inference_engines.base import (
InferenceEngineInterface,
)


class FlashRLVLLMInferenceEngine(VLLMInferenceEngine):

def _create_engine(self, *args, **kwargs):
# apply flashrl's patch just before init
from vllm.model_executor.layers.patch import apply_patch as apply_flashrl_patch

apply_flashrl_patch()

llm = vllm.LLM(*args, **kwargs)
return llm


VLLMRayActor = ray.remote(FlashRLVLLMInferenceEngine)


def create_ray_wrapped_inference_engines_flashrl(
num_inference_engines: int,
tensor_parallel_size: int,
model_dtype: str,
pretrain: str,
seed: int,
vllm_v1_disable_multiproc: bool,
enable_prefix_caching: bool,
enforce_eager: bool,
max_model_len: int,
shared_pg=None,
gpu_memory_utilization=None,
inference_engine_enable_sleep=False,
async_engine=False,
max_num_batched_tokens=8192,
max_num_seqs=1024,
sampling_params: Optional[Dict[str, Any]] = None,
tokenizer=None,
backend="vllm",
) -> List[InferenceEngineInterface]:
"""
Create a list of RayWrappedInferenceEngine instances wrapping Ray actor handles to InferenceEngineInterface instances.
"""
from skyrl_train.utils import ray_noset_visible_devices, get_all_env_variables, get_ray_pg_ready_with_timeout

assert not async_engine, "`async_engine` is not supported for FlashRL"

if backend != "vllm":
raise ValueError(f"Unsupported FlashRL backend: {backend}")

inference_engine_actors = []
noset_visible_devices = ray_noset_visible_devices(ray.get(get_all_env_variables.remote()))
# NOTE: we use the ray backend for tensor parallel size > 1 to explicitly manage resource allocation
# TODO: we should be able to support mp backend by allocating resources at engine level
distributed_executor_backend = "uni" if tensor_parallel_size == 1 else "ray"
use_hybrid_engine = shared_pg is not None
num_gpus = int(tensor_parallel_size == 1)
if use_hybrid_engine and tensor_parallel_size == 1:
# every worker will use 0.2 GPU, so that we can schedule
# 2 instances on the same GPUs.
num_gpus = 0.2

if not use_hybrid_engine:
# Create a big placement group to ensure that all inference engines are packed
bundles = [{"GPU": 1, "CPU": 1} for _ in range(num_inference_engines * tensor_parallel_size)]
shared_pg = placement_group(bundles, strategy="PACK")
get_ray_pg_ready_with_timeout(shared_pg, timeout=30)

for i in range(num_inference_engines):
bundle_indices = None
if tensor_parallel_size > 1:
bundle_indices = list(range(i * tensor_parallel_size, (i + 1) * tensor_parallel_size))

scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=shared_pg,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=i * tensor_parallel_size,
)

if backend == "vllm":

engine = VLLMRayActor.options(
num_cpus=num_gpus,
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
).remote(
model=pretrain,
enforce_eager=enforce_eager,
worker_extension_cls="skyrl_train.inference_engines.vllm.vllm_engine.WorkerWrap",
tensor_parallel_size=tensor_parallel_size,
seed=seed + i,
distributed_executor_backend=distributed_executor_backend,
max_model_len=max_model_len,
enable_prefix_caching=enable_prefix_caching,
dtype=model_dtype,
trust_remote_code=True,
vllm_v1_disable_multiproc=vllm_v1_disable_multiproc,
gpu_memory_utilization=gpu_memory_utilization,
bundle_indices=bundle_indices,
num_gpus=0.2 if use_hybrid_engine else 1,
enable_sleep_mode=inference_engine_enable_sleep,
noset_visible_devices=noset_visible_devices,
max_num_batched_tokens=max_num_batched_tokens,
max_num_seqs=max_num_seqs,
sampling_params=sampling_params,
tokenizer=tokenizer,
# only need the logprobs for the chosen token if any
max_logprobs=1,
)

inference_engine_actors.append(engine)

engines = [RayWrappedInferenceEngine(actor_handle) for actor_handle in inference_engine_actors]

if inference_engine_enable_sleep:
sleep_refs = [engine.inference_engine_actor.sleep.remote() for engine in engines]
ray.get(sleep_refs)

return engines
191 changes: 191 additions & 0 deletions skyrl-train/examples/flash_rl/main_dapo_flashrl.py
@@ -0,0 +1,191 @@
"""
uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.flashrl -m examples.flash_rl.main_dapo_flashrl
"""

import ray
import os
import hydra
import torch
from typing import List
from omegaconf import DictConfig
from skyrl_train.trainer import RayPPOTrainer
from skyrl_train.utils import initialize_ray
from skyrl_train.entrypoints.main_base import (
BasePPOExp,
config_dir,
validate_cfg,
create_remote_inference_engines_from_config,
)
from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient
from skyrl_train.inference_engines.utils import get_sampling_params_for_backend
from skyrl_train.generators.base import GeneratorInterface
from loguru import logger
from skyrl_train.generators.base import GeneratorOutput


def create_ray_wrapped_inference_engines_from_config_flashrl(cfg: DictConfig, colocate_pg, tokenizer):
from examples.flash_rl.flash_rl_engine import create_ray_wrapped_inference_engines_flashrl

return create_ray_wrapped_inference_engines_flashrl(
num_inference_engines=cfg.generator.num_inference_engines,
tensor_parallel_size=cfg.generator.inference_engine_tensor_parallel_size,
model_dtype=cfg.generator.model_dtype,
pretrain=cfg.trainer.policy.model.path,
seed=cfg.trainer.seed,
vllm_v1_disable_multiproc=cfg.generator.vllm_v1_disable_multiproc,
enable_prefix_caching=cfg.generator.enable_prefix_caching,
enforce_eager=cfg.generator.enforce_eager,
max_model_len=cfg.generator.max_input_length + cfg.generator.sampling_params.max_generate_length,
shared_pg=colocate_pg,
gpu_memory_utilization=cfg.generator.gpu_memory_utilization,
inference_engine_enable_sleep=cfg.trainer.placement.colocate_all,
async_engine=cfg.generator.async_engine,
max_num_batched_tokens=cfg.generator.max_num_batched_tokens,
max_num_seqs=cfg.generator.max_num_seqs,
sampling_params=get_sampling_params_for_backend(cfg.generator.backend, cfg.generator.sampling_params),
tokenizer=tokenizer,
backend=cfg.generator.backend,
)


class DAPOTrainer(RayPPOTrainer):
"""
Custom trainer for DAPO.

Overrides the postprocess_generator_output method to additionally apply soft overlong punishment to rewards.
"""

@torch.no_grad()
def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: List[str]) -> GeneratorOutput:
"""
Overrides the postprocess_generator_output method to additionally apply DAPO specific soft overlong punishment to rewards.

Args:
generator_output: GeneratorOutput
uids: List[str]

Returns:
GeneratorOutput
"""
overlong_buffer_len = self.cfg.trainer.algorithm.overlong_buffer.len
overlong_buffer_penalty_factor = self.cfg.trainer.algorithm.overlong_buffer.penalty_factor
# modify rewards here
prompt_token_ids = generator_output["prompt_token_ids"]
response_ids = generator_output["response_ids"]
rewards = generator_output["rewards"]

assert not isinstance(rewards[0], list), "we assume verifiable sequence level rewards here"

# get the prompt length
prompt_lengths = [len(prompt) for prompt in prompt_token_ids]

# get the response length
response_lengths = [len(response) for response in response_ids]

# get the max context length
max_context_length = (
self.cfg.generator.max_input_length + self.cfg.generator.sampling_params.max_generate_length
)

# apply soft overlong punishment
for i, (prompt_length, response_length) in enumerate(zip(prompt_lengths, response_lengths)):
# max_exceed_length is the beginning of the overlong buffer
max_exceed_length = max_context_length - overlong_buffer_len - prompt_length
# if the response is within the overlong buffer, apply the penalty
if response_length > max_exceed_length and response_length <= max_context_length - prompt_length:
exceed_length = response_length - max_exceed_length
penalty = exceed_length / overlong_buffer_len * overlong_buffer_penalty_factor

rewards[i] -= penalty
# if the response is outside the overlong buffer, set the reward to 0
elif response_length > max_context_length - prompt_length:
# if self.cfg.generator.apply_overlong_filtering is true, loss masks are already set to 0 for these responses
rewards[i] = 0.0

generator_output["rewards"] = rewards

# use base class impl for metrics and per-token reward conversion
return super().postprocess_generator_output(generator_output, uids)


class DAPOExp(BasePPOExp):
def get_trainer(self, *args, **kwargs):
return DAPOTrainer(*args, **kwargs)

def _setup_trainer(self):
"""Setup and return the trainer.

Instantiates the trainer and all the associated models for training.

Returns:
RayPPOTrainer: The trainer.
"""
logger.info(self.get_cfg_as_str(self.cfg))
os.makedirs(self.cfg.trainer.export_path, exist_ok=True)
os.makedirs(self.cfg.trainer.ckpt_path, exist_ok=True)

if self.cfg.trainer.strategy == "deepspeed":
from skyrl_train.workers.deepspeed.deepspeed_worker import (
PolicyWorker,
CriticWorker,
RefWorker,
RewardWorker,
)
elif self.cfg.trainer.strategy in ("fsdp", "fsdp2"):
from skyrl_train.workers.fsdp.fsdp_worker import PolicyWorker, CriticWorker, RefWorker, RewardWorker
else:
raise ValueError(f"Unknown strategy type: {self.cfg.trainer.strategy}")

# NOTE (sumanthrh): Instantiate tracker before trainer init.
# We have custom validation before this step to give better error messages.
tracker = self.get_tracker()

tokenizer = self.tokenizer
if self.cfg.generator.run_engines_locally:
inference_engines = create_ray_wrapped_inference_engines_from_config_flashrl(
self.cfg, self.colocate_pg, tokenizer
)
else:
inference_engines = create_remote_inference_engines_from_config(self.cfg)

inference_engine_client = InferenceEngineClient(inference_engines)

generator: GeneratorInterface = self.get_generator(self.cfg, tokenizer, inference_engine_client)

trainer = self.get_trainer(
cfg=self.cfg,
tracker=tracker,
tokenizer=tokenizer,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
inference_engine_client=inference_engine_client,
generator=generator,
colocate_pg=self.colocate_pg,
)

# Build the models
trainer.build_models(PolicyWorker, CriticWorker, RefWorker, RewardWorker)
return trainer


@ray.remote(num_cpus=1)
def skyrl_entrypoint(cfg: DictConfig):

exp = DAPOExp(cfg)
exp.run()


@hydra.main(config_path=config_dir, config_name="ppo_base_config", version_base=None)
def main(cfg: DictConfig) -> None:
# validate the arguments
validate_cfg(cfg)

if not cfg.generator.run_engines_locally:
raise ValueError("FlashRL only supports colocated training.")

initialize_ray(cfg)
ray.get(skyrl_entrypoint.remote(cfg))


if __name__ == "__main__":
main()