[FlashRL 3/N] Add example for FP8 training with FlashRL #169
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,64 @@ | ||||||
| FlashRL + SkyRL: Training with FP8 Rollouts | ||||||
| =========================================== | ||||||
|
|
||||||
| In this example, we walk through how to train a model with FP8 rollouts using `FlashRL <https://fengyao.notion.site/flash-rl>`_ and SkyRL. | ||||||
|
|
||||||
| We provide examples for training Qwen2.5-1.5B-Instruct and Qwen3-32B with FP8 rollouts. | ||||||
|
|
||||||
| What is FlashRL? | ||||||
| ---------------- | ||||||
|
|
||||||
| FlashRL is the first RL recipe that performs rollout generation with quantized weights while preserving downstream performance. FlashRL consists of two main components: | ||||||
|
|
||||||
| - Truncated Importance Sampling (TIS): In scalable RL frameworks, the policy model and rollout generation are typically managed by different libraries (e.g., FSDP and vLLM, respectively), which leads to a mismatch between the two probability distributions. TIS addresses this rollout-training mismatch by applying a token-level correction factor (based on the truncated importance-sampling ratio) to the policy loss; see the sketch after this list. | ||||||
| - Online Quantization Support: While vLLM supports inference with quantized weights, using this for RL training is tricky because the policy weights change at every training step. FlashRL patches vLLM to support FP8 and INT8 weight updates during training. | ||||||
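For intuition, here is a minimal sketch of how a truncated importance-sampling correction can be applied to a simple token-level policy-gradient loss. The function name and exact loss form are illustrative rather than SkyRL's implementation; `cap` plays the role of the `trainer.algorithm.tis_imp_ratio_cap` setting shown later in this example.

```python
# Sketch of a truncated importance-sampling (TIS) correction applied to a
# token-level policy-gradient loss. Illustrative only, not SkyRL's exact loss.
import torch


def tis_policy_loss(policy_logprobs, rollout_logprobs, advantages, loss_mask, cap=2.0):
    # Ratio between the trainer's (e.g., FSDP) and the rollout engine's (e.g., vLLM)
    # token log-probabilities; this corrects for the distribution mismatch.
    imp_ratio = torch.exp(policy_logprobs - rollout_logprobs).detach()
    # Truncate from above to keep the variance of the correction bounded.
    imp_ratio = imp_ratio.clamp(max=cap)
    # Token-level policy-gradient loss scaled by the TIS factor.
    per_token_loss = -imp_ratio * advantages * policy_logprobs
    return (per_token_loss * loss_mask).sum() / loss_mask.sum().clamp(min=1)
```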
|
|
||||||
|
|
||||||
| FlashRL + SkyRL | ||||||
| --------------- | ||||||
|
|
||||||
| SkyRL now supports an initial integration with FlashRL. Currently, we only support training with `online FP8 quantization <https://docs.vllm.ai/en/v0.9.2/features/quantization/fp8.html#online-dynamic-quantization>`_ in vLLM. To enable it, set ``FLASHRL_CONFIG=fp8_vllm`` in your environment variables and pass the ``--extra flashrl`` flag when running the training script. | ||||||
|
|
||||||
|
|
||||||
| .. warning:: | ||||||
|
|
||||||
| FlashRL integration only supports single-turn training at the moment. | ||||||
|
|
||||||
|
|
||||||
| How does it work? | ||||||
| ~~~~~~~~~~~~~~~~~~ | ||||||
|
|
||||||
| We pass the ``quantization="fp8"`` flag to the vLLM engine at initialization time. This means that the weights are loaded as usual in half precision and then quantized down to FP8. During training, generations are sampled as usual, only now from the quantized weights. Since we use online quantization, the scale factors used for quantizing activations are computed on the fly by vLLM internally. | ||||||
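For reference, this is roughly what enabling online FP8 quantization looks like when constructing a vLLM engine directly. This is a minimal sketch with an illustrative model name and length; SkyRL passes the same `quantization="fp8"` argument through its own engine wrapper.

```python
# Minimal sketch: online (dynamic) FP8 quantization in vLLM.
# Weights are loaded in half precision and quantized to FP8 at load time;
# activation scales are computed on the fly at runtime.
import vllm

llm = vllm.LLM(
    model="Qwen/Qwen2.5-1.5B-Instruct",  # illustrative model
    quantization="fp8",
    dtype="bfloat16",
    max_model_len=4096,
)
print(llm.generate(["What is 2 + 2?"])[0].outputs[0].text)
```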
|
|
||||||
| The sampled rollouts are then used to compute the policy loss. We apply the TIS correction factor to the policy loss and then update the policy model weights. The updated weights, in half precision, are then synced to the inference engine layer by layer, where they are loaded and quantized down to FP8 in the same way as at initialization. | ||||||
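To make the re-quantization step concrete, below is a minimal, self-contained sketch of per-tensor FP8 (E4M3) weight quantization in PyTorch. It illustrates the idea only and is not vLLM's actual quantization path.

```python
# Illustrative per-tensor FP8 (E4M3) quantization of a half-precision weight.
import torch


def quantize_fp8_per_tensor(weight: torch.Tensor):
    """Quantize a half-precision weight tensor to FP8 (E4M3) with a per-tensor scale."""
    finfo = torch.finfo(torch.float8_e4m3fn)
    # Map the tensor's max magnitude onto the FP8 representable range.
    scale = weight.abs().amax().float().clamp(min=1e-12) / finfo.max
    qweight = (weight.float() / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return qweight, scale  # dequantize as qweight.to(torch.float32) * scale


w = torch.randn(1024, 1024, dtype=torch.bfloat16)  # stand-in for a synced weight
qw, s = quantize_fp8_per_tensor(w)
print(qw.dtype, s.item())
```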
|
|
||||||
|
|
||||||
|
|
||||||
| Example | ||||||
| -------- | ||||||
|
|
||||||
| We provide two examples of training with FP8 rollouts for DAPO: one for Qwen2.5-1.5B-Instruct and one for Qwen3-32B. The FlashRL-related files are in the ``skyrl_train/examples/flash_rl/`` folder. | ||||||
|
|
||||||
|
|
||||||
| .. code-block:: bash | ||||||
| :caption: Training configuration at ``skyrl_train/examples/flash_rl/run_dapo_flashrl.sh`` | ||||||
|
|
||||||
| # path for dataset (.parquet files) containing the prompts and metadata for each question | ||||||
| DATA_DIR="$HOME/data/gsm8k" | ||||||
|
|
||||||
| uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.flashrl -m examples.flash_rl.main_dapo_flashrl \ | ||||||
| ... | ||||||
| trainer.algorithm.use_tis=true \ | ||||||
| trainer.algorithm.tis_imp_ratio_cap=2.0 \ | ||||||
| ... | ||||||
|
|
||||||
| Here, we've configured training to use TIS with an importance-sampling ratio cap of 2.0. To make sure the FlashRL patches are applied to vLLM, we set the ``FLASHRL_CONFIG`` environment variable in ``examples/flash_rl/.env.flashrl``: | ||||||
|
|
||||||
| .. code-block:: bash | ||||||
| :caption: Environment variables at ``examples/flash_rl/.env.flashrl`` | ||||||
|
|
||||||
| FLASHRL_CONFIG=fp8_vllm | ||||||
| ... | ||||||
|
|
||||||
| .. warning:: | ||||||
|
|
||||||
| FlashRL integration is experimental. While generation times can improve for large models with quantization, we've observed that the time spent in weight syncing is much higher with FlashRL for FP8, which negates some of the benefits of FP8 inference. The slowdown is primarily due to slow weight quantization in vLLM's ``process_weights_after_loading`` function, and we are actively working on improving this. | ||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,2 @@ | ||
| FLASHRL_LOGGING_LEVEL=DEBUG # optional | ||
| FLASHRL_CONFIG="fp8_vllm" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,126 @@ | ||
| from typing import List, Any, Dict, Optional | ||
| import ray | ||
| import vllm | ||
| from skyrl_train.inference_engines.vllm.vllm_engine import VLLMInferenceEngine | ||
| from skyrl_train.inference_engines.ray_wrapped_inference_engine import RayWrappedInferenceEngine | ||
| from ray.util.placement_group import PlacementGroupSchedulingStrategy, placement_group | ||
|
|
||
| from skyrl_train.inference_engines.base import ( | ||
| InferenceEngineInterface, | ||
| ) | ||
|
|
||
|
|
||
| class FlashRLVLLMInferenceEngine(VLLMInferenceEngine): | ||
|
|
||
| def _create_engine(self, *args, **kwargs): | ||
| # apply flashrl's patch just before init | ||
| from vllm.model_executor.layers.patch import apply_patch as apply_flashrl_patch | ||
|
|
||
| apply_flashrl_patch() | ||
|
|
||
| llm = vllm.LLM(*args, **kwargs) | ||
| return llm | ||
|
|
||
|
|
||
| VLLMRayActor = ray.remote(FlashRLVLLMInferenceEngine) | ||
|
|
||
|
|
||
| def create_ray_wrapped_inference_engines_flashrl( | ||
| num_inference_engines: int, | ||
| tensor_parallel_size: int, | ||
| model_dtype: str, | ||
| pretrain: str, | ||
| seed: int, | ||
| vllm_v1_disable_multiproc: bool, | ||
| enable_prefix_caching: bool, | ||
| enforce_eager: bool, | ||
| max_model_len: int, | ||
| shared_pg=None, | ||
| gpu_memory_utilization=None, | ||
| inference_engine_enable_sleep=False, | ||
| async_engine=False, | ||
| max_num_batched_tokens=8192, | ||
| max_num_seqs=1024, | ||
| sampling_params: Optional[Dict[str, Any]] = None, | ||
| tokenizer=None, | ||
| backend="vllm", | ||
| ) -> List[InferenceEngineInterface]: | ||
| """ | ||
| Create a list of RayWrappedInferenceEngine instances wrapping Ray actor handles to InferenceEngineInterface instances. | ||
| """ | ||
| from skyrl_train.utils import ray_noset_visible_devices, get_all_env_variables, get_ray_pg_ready_with_timeout | ||
|
|
||
| assert not async_engine, "`async_engine` is not supported for FlashRL" | ||
|
Collaborator: just to confirm - we can only use the offline engine for flash-rl, so only single turn rollouts?
Collaborator: maybe worth a clarification in the doc, I didn't realize until I hit this line of code
Member (Author): Yeah let me add a warning |
||
|
|
||
| if backend != "vllm": | ||
| raise ValueError(f"Unsupported FlashRL backend: {backend}") | ||
|
|
||
| inference_engine_actors = [] | ||
| noset_visible_devices = ray_noset_visible_devices(ray.get(get_all_env_variables.remote())) | ||
| # NOTE: we use the ray backend for tensor parallel size > 1 to explicitly manage resource allocation | ||
| # TODO: we should be able to support mp backend by allocating resources at engine level | ||
| distributed_executor_backend = "uni" if tensor_parallel_size == 1 else "ray" | ||
| use_hybrid_engine = shared_pg is not None | ||
| num_gpus = int(tensor_parallel_size == 1) | ||
| if use_hybrid_engine and tensor_parallel_size == 1: | ||
| # every worker will use 0.2 GPU, so that we can schedule | ||
| # 2 instances on the same GPUs. | ||
| num_gpus = 0.2 | ||
|
|
||
| if not use_hybrid_engine: | ||
| # Create a big placement group to ensure that all inference engines are packed | ||
| bundles = [{"GPU": 1, "CPU": 1} for _ in range(num_inference_engines * tensor_parallel_size)] | ||
| shared_pg = placement_group(bundles, strategy="PACK") | ||
| get_ray_pg_ready_with_timeout(shared_pg, timeout=30) | ||
|
|
||
| for i in range(num_inference_engines): | ||
| bundle_indices = None | ||
| if tensor_parallel_size > 1: | ||
| bundle_indices = list(range(i * tensor_parallel_size, (i + 1) * tensor_parallel_size)) | ||
|
|
||
| scheduling_strategy = PlacementGroupSchedulingStrategy( | ||
| placement_group=shared_pg, | ||
| placement_group_capture_child_tasks=True, | ||
| placement_group_bundle_index=i * tensor_parallel_size, | ||
| ) | ||
|
|
||
| if backend == "vllm": | ||
|
|
||
| engine = VLLMRayActor.options( | ||
| num_cpus=num_gpus, | ||
| num_gpus=num_gpus, | ||
| scheduling_strategy=scheduling_strategy, | ||
| ).remote( | ||
| model=pretrain, | ||
| enforce_eager=enforce_eager, | ||
| worker_extension_cls="skyrl_train.inference_engines.vllm.vllm_engine.WorkerWrap", | ||
| tensor_parallel_size=tensor_parallel_size, | ||
| seed=seed + i, | ||
| distributed_executor_backend=distributed_executor_backend, | ||
| max_model_len=max_model_len, | ||
| enable_prefix_caching=enable_prefix_caching, | ||
| dtype=model_dtype, | ||
| trust_remote_code=True, | ||
| vllm_v1_disable_multiproc=vllm_v1_disable_multiproc, | ||
| gpu_memory_utilization=gpu_memory_utilization, | ||
| bundle_indices=bundle_indices, | ||
| num_gpus=0.2 if use_hybrid_engine else 1, | ||
| enable_sleep_mode=inference_engine_enable_sleep, | ||
| noset_visible_devices=noset_visible_devices, | ||
| max_num_batched_tokens=max_num_batched_tokens, | ||
| max_num_seqs=max_num_seqs, | ||
| sampling_params=sampling_params, | ||
| tokenizer=tokenizer, | ||
| # only need the logprobs for the chosen token if any | ||
| max_logprobs=1, | ||
| ) | ||
|
|
||
| inference_engine_actors.append(engine) | ||
|
|
||
| engines = [RayWrappedInferenceEngine(actor_handle) for actor_handle in inference_engine_actors] | ||
|
|
||
| if inference_engine_enable_sleep: | ||
| sleep_refs = [engine.inference_engine_actor.sleep.remote() for engine in engines] | ||
| ray.get(sleep_refs) | ||
|
|
||
| return engines | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,191 @@ | ||
| """ | ||
|
||
| uv run --isolated --extra flashrl --env-file examples/flash_rl/.env.flashrl -m examples.flash_rl.main_dapo_flashrl | ||
| """ | ||
|
|
||
| import ray | ||
| import os | ||
| import hydra | ||
| import torch | ||
| from typing import List | ||
| from omegaconf import DictConfig | ||
| from skyrl_train.trainer import RayPPOTrainer | ||
| from skyrl_train.utils import initialize_ray | ||
| from skyrl_train.entrypoints.main_base import ( | ||
| BasePPOExp, | ||
| config_dir, | ||
| validate_cfg, | ||
| create_remote_inference_engines_from_config, | ||
| ) | ||
| from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient | ||
| from skyrl_train.inference_engines.utils import get_sampling_params_for_backend | ||
| from skyrl_train.generators.base import GeneratorInterface | ||
| from loguru import logger | ||
| from skyrl_train.generators.base import GeneratorOutput | ||
|
|
||
|
|
||
| def create_ray_wrapped_inference_engines_from_config_flashrl(cfg: DictConfig, colocate_pg, tokenizer): | ||
| from examples.flash_rl.flash_rl_engine import create_ray_wrapped_inference_engines_flashrl | ||
|
|
||
| return create_ray_wrapped_inference_engines_flashrl( | ||
| num_inference_engines=cfg.generator.num_inference_engines, | ||
| tensor_parallel_size=cfg.generator.inference_engine_tensor_parallel_size, | ||
| model_dtype=cfg.generator.model_dtype, | ||
| pretrain=cfg.trainer.policy.model.path, | ||
| seed=cfg.trainer.seed, | ||
| vllm_v1_disable_multiproc=cfg.generator.vllm_v1_disable_multiproc, | ||
| enable_prefix_caching=cfg.generator.enable_prefix_caching, | ||
| enforce_eager=cfg.generator.enforce_eager, | ||
| max_model_len=cfg.generator.max_input_length + cfg.generator.sampling_params.max_generate_length, | ||
| shared_pg=colocate_pg, | ||
| gpu_memory_utilization=cfg.generator.gpu_memory_utilization, | ||
| inference_engine_enable_sleep=cfg.trainer.placement.colocate_all, | ||
| async_engine=cfg.generator.async_engine, | ||
| max_num_batched_tokens=cfg.generator.max_num_batched_tokens, | ||
| max_num_seqs=cfg.generator.max_num_seqs, | ||
| sampling_params=get_sampling_params_for_backend(cfg.generator.backend, cfg.generator.sampling_params), | ||
| tokenizer=tokenizer, | ||
| backend=cfg.generator.backend, | ||
| ) | ||
|
|
||
|
|
||
| class DAPOTrainer(RayPPOTrainer): | ||
| """ | ||
| Custom trainer for DAPO. | ||
|
|
||
| Overrides the postprocess_generator_output method to additionally apply soft overlong punishment to rewards. | ||
| """ | ||
|
|
||
| @torch.no_grad() | ||
| def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: List[str]) -> GeneratorOutput: | ||
| """ | ||
| Overrides the postprocess_generator_output method to additionally apply DAPO specific soft overlong punishment to rewards. | ||
|
|
||
| Args: | ||
| generator_output: GeneratorOutput | ||
| uids: List[str] | ||
|
|
||
| Returns: | ||
| GeneratorOutput | ||
| """ | ||
| overlong_buffer_len = self.cfg.trainer.algorithm.overlong_buffer.len | ||
| overlong_buffer_penalty_factor = self.cfg.trainer.algorithm.overlong_buffer.penalty_factor | ||
| # modify rewards here | ||
| prompt_token_ids = generator_output["prompt_token_ids"] | ||
| response_ids = generator_output["response_ids"] | ||
| rewards = generator_output["rewards"] | ||
|
|
||
| assert not isinstance(rewards[0], list), "we assume verifiable sequence level rewards here" | ||
|
|
||
| # get the prompt length | ||
| prompt_lengths = [len(prompt) for prompt in prompt_token_ids] | ||
|
|
||
| # get the response length | ||
| response_lengths = [len(response) for response in response_ids] | ||
|
|
||
| # get the max context length | ||
| max_context_length = ( | ||
| self.cfg.generator.max_input_length + self.cfg.generator.sampling_params.max_generate_length | ||
| ) | ||
|
|
||
| # apply soft overlong punishment | ||
| for i, (prompt_length, response_length) in enumerate(zip(prompt_lengths, response_lengths)): | ||
| # max_exceed_length is the beginning of the overlong buffer | ||
| max_exceed_length = max_context_length - overlong_buffer_len - prompt_length | ||
| # if the response is within the overlong buffer, apply the penalty | ||
| if response_length > max_exceed_length and response_length <= max_context_length - prompt_length: | ||
| exceed_length = response_length - max_exceed_length | ||
| penalty = exceed_length / overlong_buffer_len * overlong_buffer_penalty_factor | ||
|
|
||
| rewards[i] -= penalty | ||
| # if the response is outside the overlong buffer, set the reward to 0 | ||
| elif response_length > max_context_length - prompt_length: | ||
| # if self.cfg.generator.apply_overlong_filtering is true, loss masks are already set to 0 for these responses | ||
| rewards[i] = 0.0 | ||
|
|
||
| generator_output["rewards"] = rewards | ||
|
|
||
| # use base class impl for metrics and per-token reward conversion | ||
| return super().postprocess_generator_output(generator_output, uids) | ||
|
|
||
|
|
||
| class DAPOExp(BasePPOExp): | ||
| def get_trainer(self, *args, **kwargs): | ||
| return DAPOTrainer(*args, **kwargs) | ||
|
|
||
| def _setup_trainer(self): | ||
| """Setup and return the trainer. | ||
|
|
||
| Instantiates the trainer and all the associated models for training. | ||
|
|
||
| Returns: | ||
| RayPPOTrainer: The trainer. | ||
| """ | ||
| logger.info(self.get_cfg_as_str(self.cfg)) | ||
| os.makedirs(self.cfg.trainer.export_path, exist_ok=True) | ||
| os.makedirs(self.cfg.trainer.ckpt_path, exist_ok=True) | ||
|
|
||
| if self.cfg.trainer.strategy == "deepspeed": | ||
| from skyrl_train.workers.deepspeed.deepspeed_worker import ( | ||
| PolicyWorker, | ||
| CriticWorker, | ||
| RefWorker, | ||
| RewardWorker, | ||
| ) | ||
| elif self.cfg.trainer.strategy in ("fsdp", "fsdp2"): | ||
| from skyrl_train.workers.fsdp.fsdp_worker import PolicyWorker, CriticWorker, RefWorker, RewardWorker | ||
| else: | ||
| raise ValueError(f"Unknown strategy type: {self.cfg.trainer.strategy}") | ||
|
|
||
| # NOTE (sumanthrh): Instantiate tracker before trainer init. | ||
| # We have custom validation before this step to give better error messages. | ||
| tracker = self.get_tracker() | ||
|
|
||
| tokenizer = self.tokenizer | ||
| if self.cfg.generator.run_engines_locally: | ||
| inference_engines = create_ray_wrapped_inference_engines_from_config_flashrl( | ||
|
||
| self.cfg, self.colocate_pg, tokenizer | ||
| ) | ||
| else: | ||
| inference_engines = create_remote_inference_engines_from_config(self.cfg) | ||
|
|
||
| inference_engine_client = InferenceEngineClient(inference_engines) | ||
|
|
||
| generator: GeneratorInterface = self.get_generator(self.cfg, tokenizer, inference_engine_client) | ||
|
|
||
| trainer = self.get_trainer( | ||
| cfg=self.cfg, | ||
| tracker=tracker, | ||
| tokenizer=tokenizer, | ||
| train_dataset=self.train_dataset, | ||
| eval_dataset=self.eval_dataset, | ||
| inference_engine_client=inference_engine_client, | ||
| generator=generator, | ||
| colocate_pg=self.colocate_pg, | ||
| ) | ||
|
|
||
| # Build the models | ||
| trainer.build_models(PolicyWorker, CriticWorker, RefWorker, RewardWorker) | ||
| return trainer | ||
|
|
||
|
|
||
| @ray.remote(num_cpus=1) | ||
| def skyrl_entrypoint(cfg: DictConfig): | ||
|
|
||
| exp = DAPOExp(cfg) | ||
| exp.run() | ||
|
|
||
|
|
||
| @hydra.main(config_path=config_dir, config_name="ppo_base_config", version_base=None) | ||
| def main(cfg: DictConfig) -> None: | ||
| # validate the arguments | ||
| validate_cfg(cfg) | ||
|
|
||
| if not cfg.generator.run_engines_locally: | ||
| raise ValueError("FlashRL only supports colocated training.") | ||
|
|
||
| initialize_ray(cfg) | ||
| ray.get(skyrl_entrypoint.remote(cfg)) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() | ||