# OpenEnv Integration for Training LLMs with Environments

## Overview
[OpenEnv](https://github.com/meta-pytorch/OpenEnv) is an open-source framework from Meta's PyTorch team for defining, deploying, and interacting with environments in RL and agentic workflows. It provides [Gymnasium-style APIs](https://gymnasium.farama.org) (`reset()`, `step()`, `state()`) for interacting with environments in a standard way, and supports running them as backend servers, for example over HTTP or as Docker containers. A collection of ready-to-use environments is available on the Hugging Face Hub under the [openenv org](https://huggingface.co/openenv) and its [Environment Hub collection](https://huggingface.co/collections/openenv/environment-hub).

Here, we'll focus on the **integration of OpenEnv with TRL**; check out the resources above to learn more about the framework itself.

## Installation

To use OpenEnv with TRL, install the framework:

```bash
pip install git+https://github.com/meta-pytorch/OpenEnv.git
```
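
Once OpenEnv (and Docker) is available, interacting with an environment follows the Gymnasium-style `reset()`/`step()` pattern. Below is a minimal sketch using the Echo environment that also appears in the TRL example later in this guide; the exact observation and result types are defined by each environment, so treat the field access as illustrative:

```python
from envs.echo_env import EchoEnv, EchoAction

# Spin up the Echo environment in a local Docker container
client = EchoEnv.from_docker_image("echo-env:latest")

client.reset()                                     # start a new episode
result = client.step(EchoAction(message="hello"))  # send an action to the environment
print(result.reward)                               # environment-computed reward
```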

## Using `rollout_func` with OpenEnv environments

TRL's [`GRPOTrainer`] supports _custom rollout logic_ through the `rollout_func` argument. This lets you override the trainer's default text-generation loop and interact directly with OpenEnv environments, for example to compute environment-based rewards instead of purely model-based ones.

### Rollout Function Signature

A rollout function must have the following signature:

```python
def rollout_func(
    prompts: list[str],
    args: GRPOConfig,
    processing_class,
) -> dict[str, list]:
    """
    Custom rollout function for generation and reward computation.

    Args:
        prompts: List of prompts to generate from.
        args: GRPOConfig containing sampling parameters (temperature, top_p, etc.).
        processing_class: Tokenizer/processor for encoding and decoding text.

    Returns:
        Dictionary containing:
        - prompt_ids: List of token IDs for each prompt
        - completion_ids: List of token IDs for each completion
        - logprobs: List of log probabilities for each token
        Any additional fields are forwarded to the reward functions as kwargs.
    """
    pass
```
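To make the contract concrete, here is a sketch of the dictionary a rollout function might return for a batch of two completions. The token IDs, log probabilities, and the extra `env_reward` values are made up for illustration:

```python
# Illustrative values only: the three required fields plus one optional extra field.
example_rollout_result = {
    "prompt_ids": [[101, 2023, 318], [101, 9906, 0]],     # token IDs for each prompt
    "completion_ids": [[318, 257, 1332], [2159, 0]],      # token IDs for each completion
    "logprobs": [[-0.11, -0.52, -1.03], [-0.27, -0.88]],  # per-token log probabilities
    "env_reward": [1.0, 0.0],                             # extra field, forwarded to reward functions
}
```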
> [!NOTE]
> Any extra fields in the returned dictionary (beyond the required three) are automatically forwarded to your reward functions. This makes it easy to propagate signals such as environment rewards or auxiliary metrics from the rollout step.

### Integration pattern

The typical pattern when combining OpenEnv with TRL looks like this:

1. Start or connect to an OpenEnv environment (e.g., an HTTP endpoint or Dockerized environment).
2. Generate completions from your model, for example via a vLLM inference server (`use_vllm=True`, `vllm_mode="server"`).
3. Step through the environment using each completion to compute rewards or metrics.
4. Add the environment results (e.g., `env_reward`) to the rollout result dictionary.
5. Access those rewards inside your reward function via `**kwargs`.

By using OpenEnv in this loop, you can:

* Train with realistic or interactive feedback, not just static reward functions.
* Plug in custom simulators, web APIs, or evaluators as environments.
* Pass structured reward signals back into RL training seamlessly.

## A simple example

The [echo.py](../../examples/scripts/openenv/echo.py) script demonstrates a minimal, end-to-end integration between TRL and OpenEnv. In this example, the Echo environment rewards completions based on their text length, encouraging the model to generate longer outputs. This pattern can be extended to any custom environment that provides structured feedback or task-based rewards:
```python
import requests

from datasets import Dataset
from envs.echo_env import EchoEnv, EchoAction
from trl import GRPOConfig, GRPOTrainer

# Create HTTP client for Echo Environment
client = EchoEnv.from_docker_image("echo-env:latest")

def rollout_func(prompts, args, processing_class):
    # 1. Generate completions via vLLM inference server (running on port 8000)
    payload = {
        "prompts": prompts,
        "n": args.num_generations,
        "temperature": args.temperature,
        "max_tokens": args.max_completion_length,
    }
    response = requests.post("http://0.0.0.0:8000/generate/", json=payload)
    result = response.json()

    completions_text = processing_class.batch_decode(
        result["completion_ids"],
        skip_special_tokens=True,
    )

    # 2. Step through the environment to get rewards
    client.reset()
    env_rewards = []
    for msg in completions_text:
        env_result = client.step(EchoAction(message=msg))
        env_rewards.append(env_result.reward)

    # 3. Add environment rewards as an extra field
    result["env_reward"] = env_rewards
    return result

def reward_from_env(completions, **kwargs):
    """Extract environment rewards passed via rollout_func kwargs."""
    env_rewards = kwargs.get("env_reward", [])
    return [float(reward) for reward in env_rewards] if env_rewards else [0.0] * len(completions)

dataset = Dataset.from_dict({"prompt": ["You are an AI that interacts with an *Echo* environment. Word to echo:"] * 64})

# Set up the trainer with the custom rollout
trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=reward_from_env,
    train_dataset=dataset,
    rollout_func=rollout_func,  # Use custom rollout
    args=GRPOConfig(
        vllm_mode="server",
        use_vllm=True,
        num_train_epochs=1,
        num_generations=8,
        max_completion_length=2048,
        per_device_train_batch_size=8,
        gradient_accumulation_steps=4,
    ),
)
trainer.train()
```

That's it! Now that you've seen the full example, let's unpack how the main pieces fit together.

1. **Environment client:** `EchoEnv` implements an HTTP interface to interact with the environment server.
2. **Custom rollout:** The `rollout_func` generates completions via the vLLM server and steps through the environment to collect rewards.
3. **Extra fields:** The rollout adds `env_reward` to the result dictionary, which is automatically forwarded to the reward functions.
4. **Reward function:** `reward_from_env` extracts `env_reward` from `kwargs` and applies the environment-computed rewards during training.

> [!WARNING]
> The `rollout_func` is currently only supported when using vLLM in server mode (`use_vllm=True`, `vllm_mode="server"`).

### Running the Example

The example requires two GPUs:

```bash
# Terminal 1: Start vLLM inference server
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port 8000

# Terminal 2: Run GRPO training with OpenEnv
CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/echo.py
```

To learn more about how to create custom environments, see the [OpenEnv documentation](https://github.com/meta-pytorch/OpenEnv/blob/main/src/envs/README.md).
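
As a rough mental model only (these are not OpenEnv's actual base classes; the documentation linked above describes the real interfaces and the HTTP server wrapper), a custom environment boils down to an action type, a result type carrying an observation and a reward, and `reset()`/`step()` methods:

```python
# Hypothetical sketch of the Gymnasium-style contract a custom environment exposes.
from dataclasses import dataclass


@dataclass
class CountVowelsAction:
    message: str


@dataclass
class CountVowelsResult:
    observation: str
    reward: float
    done: bool


class CountVowelsEnv:
    """Toy environment: the reward is the number of vowels in the message."""

    def reset(self) -> CountVowelsResult:
        return CountVowelsResult(observation="Send me some text.", reward=0.0, done=False)

    def step(self, action: CountVowelsAction) -> CountVowelsResult:
        reward = float(sum(ch in "aeiouAEIOU" for ch in action.message))
        return CountVowelsResult(observation=action.message, reward=reward, done=True)
```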

## Another example: Catch

The [catch.py](../../examples/scripts/openenv/catch.py) script demonstrates training an LLM to play the Catch environment from OpenEnv.
Catch is a simple 10×5 grid game: a ball falls from the top while you control a paddle at the bottom, moving left, moving right, or staying in place. Catching the ball yields a reward of +1; missing it yields −1.

```txt
· · ● · ·
· · · · ·
· · · · ·
· · · · ·
· · · · ·
· · · · ·
· · · · ·
· · · · ·
· · · · ·
· · █ · ·
```

The model is prompted with a description of the environment and the current state, and trained to output actions that maximize the environment reward. A toy sketch of how such a prompt could be built is shown below, followed by the reward curve from the training run.
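
For illustration, here is a small sketch (not the code from `catch.py`) of how a grid observation like the one above could be rendered into a prompt and how the model's reply could be mapped back to an action index. The helper names and the 0/1/2 action encoding are assumptions, not the script's actual implementation:

```python
# Hypothetical helpers, not taken from the actual catch.py script.
def render_catch_prompt(ball_row: int, ball_col: int, paddle_col: int,
                        rows: int = 10, cols: int = 5) -> str:
    """Render the grid as text and ask the model for an action."""
    grid = [["·"] * cols for _ in range(rows)]
    grid[ball_row][ball_col] = "●"
    grid[rows - 1][paddle_col] = "█"
    board = "\n".join(" ".join(row) for row in grid)
    return (
        "You control the paddle (█) at the bottom of a 10x5 grid and must catch the falling ball (●).\n\n"
        f"{board}\n\n"
        "Reply with exactly one action: LEFT, STAY, or RIGHT."
    )


def parse_action(completion: str) -> int:
    """Map the model's reply to an action index (assumed encoding: 0 = left, 1 = stay, 2 = right)."""
    text = completion.upper()
    if "LEFT" in text:
        return 0
    if "RIGHT" in text:
        return 2
    return 1
```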

<iframe src="https://trl-lib-trackio.hf.space?project=openenv&metrics=train/rewards/reward_from_env/mean&runs=qgallouedec-1761202871&sidebar=hidden&navbar=hidden" style="width:600px; height:500px; border:0;"></iframe>