
Commit 2819a8f

lewtun, kashif, and qgallouedec authored
🕹️ Add rollout function for OpenEnv integration (#4310)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
1 parent e1c87e3 commit 2819a8f

File tree

5 files changed

+677
-11
lines changed


docs/source/_toctree.yml

Lines changed: 2 additions & 0 deletions

@@ -109,4 +109,6 @@
 - sections:
   - local: bco_trainer
     title: BCO
+  - local: openenv
+    title: OpenEnv Integration
   title: Experimental

docs/source/openenv.md

Lines changed: 178 additions & 0 deletions

@@ -0,0 +1,178 @@
# OpenEnv Integration for Training LLMs with Environments

## Overview

[OpenEnv](https://github.com/meta-pytorch/OpenEnv) is an open-source framework from Meta's PyTorch team for defining, deploying, and interacting with environments in RL and agentic workflows. It offers [Gymnasium-style APIs](https://gymnasium.farama.org) (`reset()`, `step()`, `state()`) to interface with environments in a standard manner, and supports running them as backend servers, for example over HTTP or as Docker containers. A collection of ready-to-use OpenEnv environments is available on the [Hugging Face Hub](https://huggingface.co/collections/openenv/environment-hub) under the [openenv org](https://huggingface.co/openenv).

Here, we focus on the **integration of OpenEnv with TRL**; check out the resources above to learn more about OpenEnv itself.

## Installation

To use OpenEnv with TRL, install the framework:

```bash
pip install git+https://github.com/meta-pytorch/OpenEnv.git
```
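
As a quick sanity check, you can try importing one of the environment clients. The `envs.echo_env` module below is the one used later in this doc; this snippet assumes it ships with the OpenEnv package:

```python
# Quick sanity check that the OpenEnv package is importable
# (assumes the Echo environment client used later in this doc ships with it).
from envs.echo_env import EchoEnv

print(EchoEnv)
```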

## Using `rollout_func` with OpenEnv environments

TRL's [`GRPOTrainer`] supports _custom rollout logic_ through the `rollout_func` argument. This lets you override the trainer's default text-generation loop and directly interact with OpenEnv environments — for example, to compute environment-based rewards instead of purely model-based ones.

### Rollout Function Signature

A rollout function must have the following signature:

```python
def rollout_func(
    prompts: list[str],
    args: GRPOConfig,
    processing_class,
) -> dict[str, list]:
    """
    Custom rollout function for generation and reward computation.

    Args:
        prompts: List of prompts to generate from.
        args: GRPOConfig containing sampling parameters (temperature, top_p, etc.).
        processing_class: Tokenizer/processor for encoding/decoding.

    Returns:
        Dictionary containing:
            - prompt_ids: list of token IDs for each prompt
            - completion_ids: list of token IDs for each completion
            - logprobs: list of log probabilities for each token
        Any additional fields are forwarded to reward functions as kwargs.
    """
    pass
```

> [!NOTE]
> Any extra fields in the returned dictionary (beyond the required three) are automatically forwarded to your reward functions. This makes it easy to propagate signals such as environment rewards or auxiliary metrics from the rollout step.
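
For example, if the rollout returns an extra `env_reward` field (as the Echo example below does), a reward function can read it straight from `kwargs`. This is a minimal sketch rather than a required signature:

```python
def reward_with_env_signal(completions, **kwargs):
    # "env_reward" is whatever extra field the rollout function added to its
    # result dict; the trainer forwards it here automatically.
    env_rewards = kwargs.get("env_reward", [])
    return [float(r) for r in env_rewards] if env_rewards else [0.0] * len(completions)
```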

### Integration pattern

The typical pattern when combining OpenEnv with TRL looks like this (a minimal skeleton of the loop follows the list):

1. Start or connect to an OpenEnv environment (e.g., an HTTP endpoint or Dockerized env).
2. Generate completions from your model — for example, via a vLLM inference server (`use_vllm=True`, `vllm_mode="server"`).
3. Step through the environment using each completion to compute rewards or metrics.
4. Add environment results (e.g., `env_reward`) to the rollout result dict.
5. Access those rewards inside your reward function via `**kwargs`.
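
Here is a bare-bones sketch of that loop. The factory function, its parameter names, and the `message=` action field are illustrative assumptions borrowed from the Echo example that follows, not a fixed TRL API; the payload and response keys match the vLLM server used in that example.

```python
import requests


def make_rollout_func(env_client, action_cls, vllm_url="http://0.0.0.0:8000/generate/"):
    """Build a rollout_func bound to an OpenEnv client (placeholder names)."""

    def rollout_func(prompts, args, processing_class):
        # Steps 1-2: generate completions via the vLLM inference server.
        payload = {
            "prompts": prompts,
            "n": args.num_generations,
            "temperature": args.temperature,
            "max_tokens": args.max_completion_length,
        }
        result = requests.post(vllm_url, json=payload).json()
        completions = processing_class.batch_decode(result["completion_ids"], skip_special_tokens=True)

        # Step 3: step the environment with each completion to collect rewards.
        env_client.reset()
        rewards = [env_client.step(action_cls(message=text)).reward for text in completions]

        # Step 4: attach environment results as an extra field; the trainer forwards
        # it to reward functions as a keyword argument (step 5).
        result["env_reward"] = rewards
        return result

    return rollout_func
```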

By using OpenEnv in this loop, you can:

* Train with realistic or interactive feedback (not just static reward functions).
* Plug in custom simulators, web APIs, or evaluators as environments.
* Pass structured reward signals back into RL training seamlessly.

## A simple example

The [echo.py](../../examples/scripts/openenv/echo.py) script demonstrates a minimal, end-to-end integration between TRL and OpenEnv. In this example, the Echo environment rewards completions based on their text length, encouraging the model to generate longer outputs. This pattern can be extended to any custom environment that provides structured feedback or task-based rewards:

```python
import requests

from datasets import Dataset
from envs.echo_env import EchoEnv, EchoAction
from trl import GRPOConfig, GRPOTrainer

# Create HTTP client for the Echo environment
client = EchoEnv.from_docker_image("echo-env:latest")

def rollout_func(prompts, args, processing_class):
    # 1. Generate completions via vLLM inference server (running on port 8000)
    payload = {
        "prompts": prompts,
        "n": args.num_generations,
        "temperature": args.temperature,
        "max_tokens": args.max_completion_length,
    }
    response = requests.post("http://0.0.0.0:8000/generate/", json=payload)
    result = response.json()

    completions_text = processing_class.batch_decode(
        result["completion_ids"],
        skip_special_tokens=True,
    )

    # 2. Step through the environment to get rewards
    client.reset()
    env_rewards = []
    for msg in completions_text:
        env_result = client.step(EchoAction(message=msg))
        env_rewards.append(env_result.reward)

    # 3. Add environment rewards as an extra field
    result["env_reward"] = env_rewards
    return result

def reward_from_env(completions, **kwargs):
    """Extract environment rewards passed via rollout_func kwargs."""
    env_rewards = kwargs.get("env_reward", [])
    return [float(reward) for reward in env_rewards] if env_rewards else [0.0] * len(completions)

dataset = Dataset.from_dict({"prompt": ["You are an AI that interacts with an *Echo* environment. Word to echo:"] * 64})

# Setup trainer with custom rollout
trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=reward_from_env,
    train_dataset=dataset,
    rollout_func=rollout_func,  # Use custom rollout
    args=GRPOConfig(
        vllm_mode="server",
        use_vllm=True,
        num_train_epochs=1,
        num_generations=8,
        max_completion_length=2048,
        per_device_train_batch_size=8,
        gradient_accumulation_steps=4,
    ),
)
trainer.train()
```

That's it! Now that you've seen the full example, let's unpack how the main pieces fit together.

1. **Environment client:** `EchoEnv` implements an HTTP interface to interact with the environment server.
2. **Custom rollout:** The `rollout_func` generates completions and steps through the environment to collect rewards.
3. **Extra fields:** The rollout adds `env_reward` to the result dictionary, which is automatically passed to reward functions.
4. **Reward function:** Extracts `env_reward` from `kwargs` to apply environment-computed rewards during training.

> [!WARNING]
> The `rollout_func` is currently only supported when using vLLM in server mode (`use_vllm=True`, `vllm_mode="server"`).

### Running the Example

The example requires two GPUs: one for the vLLM inference server and one for training.

```bash
# Terminal 1: Start vLLM inference server
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port 8000

# Terminal 2: Run GRPO training with OpenEnv
CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/echo.py
```

To learn more about how to create custom environments, see the [OpenEnv documentation](https://github.com/meta-pytorch/OpenEnv/blob/main/src/envs/README.md).

## Another example: Catch

The [catch.py](../../examples/scripts/openenv/catch.py) script demonstrates training an LLM to play the Catch environment from OpenEnv. Catch is a simple 10×5 grid game: a ball falls from the top, and you control a paddle at the bottom, moving left, moving right, or staying in place to catch the ball for a +1 reward or miss it for a -1 reward.

```txt
· · ● · ·
· · · · ·
· · · · ·
· · · · ·
· · · · ·
· · · · ·
· · · · ·
· · · · ·
· · · · ·
· · █ · ·
```

The model is prompted with a description of the environment and the current state, and is trained to output actions that maximize the environment reward.
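
Mapping the model's free-form completion back to one of the three discrete actions is the main piece of environment-specific glue. The parser below is a hypothetical sketch: the function name, keyword matching, fallback action, and the 0/1/2 action encoding are illustrative assumptions, not taken from catch.py:

```python
def parse_action(completion: str) -> int:
    """Map a text completion to a Catch action index (assumed: 0=left, 1=stay, 2=right)."""
    text = completion.lower()
    if "left" in text:
        return 0
    if "right" in text:
        return 2
    return 1  # default to "stay" when the completion is ambiguous
```

Below is the reward curve from training:
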
<iframe src="https://trl-lib-trackio.hf.space?project=openenv&metrics=train/rewards/reward_from_env/mean&runs=qgallouedec-1761202871&sidebar=hidden&navbar=hidden" style="width:600px; height:500px; border:0;"></iframe>
