diff --git a/skyrl-train/.env.llm_judge b/skyrl-train/.env.llm_judge
new file mode 100644
index 0000000000..3b07cf7a24
--- /dev/null
+++ b/skyrl-train/.env.llm_judge
@@ -0,0 +1,3 @@
+OPENAI_API_KEY=""
+# optionally, set your WandB API key if logging with wandb
+# WANDB_API_KEY=
\ No newline at end of file
diff --git a/skyrl-train/examples/llm_as_a_judge/gsm8k_dataset_judge.py b/skyrl-train/examples/llm_as_a_judge/gsm8k_dataset_judge.py
new file mode 100644
index 0000000000..a50d72fdce
--- /dev/null
+++ b/skyrl-train/examples/llm_as_a_judge/gsm8k_dataset_judge.py
@@ -0,0 +1,90 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Preprocess the GSM8k dataset to parquet format
+"""
+
+import argparse
+import re
+import os
+
+import datasets
+
+
+def extract_solution(solution_str):
+    solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str)
+    assert solution is not None
+    final_solution = solution.group(0)
+    final_solution = final_solution.split("#### ")[1].replace(",", "")
+    return final_solution
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--output_dir", default="~/data/gsm8k_llm_judge")
+
+    args = parser.parse_args()
+
+    args.output_dir = os.path.expanduser(args.output_dir)
+
+    data_source = "openai/gsm8k"
+
+    dataset = datasets.load_dataset(data_source, "main")
+
+    train_dataset = dataset["train"]
+    val_dataset = dataset["test"]
+
+    instruction_following = 'Let\'s think step by step and output the final answer after "####".'
+
+    # add a row to each data item that represents a unique id
+    def make_map_fn(split):
+        def process_fn(example, idx):
+            question_raw = example.pop("question")
+
+            question = question_raw + " " + instruction_following
+
+            answer_raw = example.pop("answer")
+            solution = extract_solution(answer_raw)
+            data = {
+                "data_source": data_source,
+                "prompt": [
+                    {
+                        "role": "user",
+                        "content": question,
+                    }
+                ],
+                # TODO: just repeating the full data preprocess script for a single env change isn't very convenient.
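+                # Note: the "env_class" value below is the one substantive change relative to the
+                # standard GSM8k preprocessing; it routes each sample to the judge environment
+                # registered in main_llm_judge.py.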
+ "env_class": "llm_as_a_judge", + "reward_spec": { + "method": "rule", + "ground_truth": solution, + }, + "extra_info": { + "split": split, + "index": idx, + "answer": answer_raw, + "question": question_raw, + }, + } + return data + + return process_fn + + train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True) + val_dataset = val_dataset.map(function=make_map_fn("test"), with_indices=True) + + output_dir = args.output_dir + os.makedirs(output_dir, exist_ok=True) + train_dataset.to_parquet(os.path.join(output_dir, "train.parquet")) + val_dataset.to_parquet(os.path.join(output_dir, "validation.parquet")) diff --git a/skyrl-train/examples/llm_as_a_judge/llm_judge_env.py b/skyrl-train/examples/llm_as_a_judge/llm_judge_env.py new file mode 100644 index 0000000000..a006f9c7a9 --- /dev/null +++ b/skyrl-train/examples/llm_as_a_judge/llm_judge_env.py @@ -0,0 +1,84 @@ +from skyrl_gym.envs.base_text_env import BaseTextEnv, BaseTextEnvStepOutput +from typing import Any +from typing import Dict +from omegaconf import DictConfig +from openai import OpenAI +import os +import re + +PROMPT = """ +You are a strict math evaluation assistant. + +Compare the following **gold** and **predicted** math solutions. Your job is to determine if the predicted solution is mathematically correct and if the predicted solution ends with a line of the form: + +#### + +You must only give a score of "1" if: +- The final line of the predicted solution **ends with `#### `**, and +- The number **matches the final answer in the gold solution** exactly. + +Instructions: +- You may provide internal reasoning or explanation before giving your final judgment. +- Your final judgment must appear as a separate line at the end of your response, in the format: + +### Final Score: 1 + +or + +### Final Score: 0 + +Do not include any explanation after the final score. +""" + + +class GSM8kLLMJudgeEnv(BaseTextEnv): + """ + Example implementtion of GSM8k environment with LLM as judge. + + Use LLM as judge to evaluate the answer similarity with the ground truth. 
+ """ + + def __init__(self, env_config: DictConfig, extras: Dict[str, Any] = {}): + super().__init__() + + assert "reward_spec" in extras, "reward_spec field is required" + assert "ground_truth" in extras["reward_spec"], "ground_truth is required in reward_spec field" + self.ground_truth = extras["reward_spec"]["ground_truth"] + + # Set up OpenAI client + openai_api_key = os.getenv("OPENAI_API_KEY") + if openai_api_key is None: + raise ValueError("`OPENAI_API_KEY` must be set for Llm as a judge env") + self.llm_judge_client = OpenAI(base_url=env_config.base_url, api_key=openai_api_key) + self.model = env_config.model + + def _get_reward(self, action: str) -> float: + message = PROMPT + f"\n\nGOLD SOLUTION:\n{self.ground_truth}\n\nPREDICTED SOLUTION:\n{action}\n\nAnswer:" + + try: + response = self.llm_judge_client.chat.completions.create( + model=self.model, messages=[{"role": "user", "content": message}] + ) + reply = response.choices[0].message.content.strip() + + # Try to parse score from "### Final Score: x" + match = re.search(r"### Final Score:\s*([01](?:\.0)?)", reply) + if match: + return float(match.group(1)) + + # Fallback: raw "1" or "0" + if reply.strip() in {"1", "0"}: + return float(reply.strip()) + + print(f"Unrecognized reward output: {reply}") + return 0.0 + + except Exception as e: + print(f"LLM Judge error: {type(e).__name__}: {e}") + return 0.0 + + def step(self, action: str) -> BaseTextEnvStepOutput: + done = True + reward = self._get_reward(action) + + return BaseTextEnvStepOutput(observations=[], reward=reward, done=done, metadata={}) diff --git a/skyrl-train/examples/llm_as_a_judge/main_llm_judge.py b/skyrl-train/examples/llm_as_a_judge/main_llm_judge.py new file mode 100644 index 0000000000..c3e8050d95 --- /dev/null +++ b/skyrl-train/examples/llm_as_a_judge/main_llm_judge.py @@ -0,0 +1,36 @@ +""" +uv run --isolated --extra vllm -m examples.llm_as_a_judge.main_llm_judge +""" + +import ray +import hydra +from omegaconf import DictConfig +from skyrl_train.utils import initialize_ray +from skyrl_train.entrypoints.main_base import BasePPOExp, config_dir, validate_cfg +from skyrl_gym.envs import register + + +@ray.remote(num_cpus=1) +def skyrl_entrypoint(cfg: DictConfig): + # Register the llm_as_a_judge environment inside the entrypoint task (no need to modify the skyrl-gym package). + register( + id="llm_as_a_judge", + entry_point="examples.llm_as_a_judge.llm_judge_env:GSM8kLLMJudgeEnv", + ) + + # make sure that the training loop is not run on the head node. + exp = BasePPOExp(cfg) + exp.run() + + +@hydra.main(config_path=config_dir, config_name="ppo_base_config", version_base=None) +def main(cfg: DictConfig) -> None: + # validate the arguments + validate_cfg(cfg) + + initialize_ray(cfg) + ray.get(skyrl_entrypoint.remote(cfg)) + + +if __name__ == "__main__": + main() diff --git a/skyrl-train/examples/llm_as_a_judge/run_llm_judge.sh b/skyrl-train/examples/llm_as_a_judge/run_llm_judge.sh new file mode 100644 index 0000000000..d3a468cbc6 --- /dev/null +++ b/skyrl-train/examples/llm_as_a_judge/run_llm_judge.sh @@ -0,0 +1,57 @@ +set -x + +# Colocated GRPO training+generation for Qwen2.5-Coder-1.5B-Instruct on GSM8k dataset. +# Uses 1 node with 4 GPUs. 
+# uv run examples/llm_as_a_judge/gsm8k_dataset_judge.py --output_dir $HOME/data/gsm8k_llm_judge
+# add OPENAI_API_KEY and WANDB_API_KEY to .env.llm_judge
+# bash examples/llm_as_a_judge/run_llm_judge.sh
+
+DATA_DIR="$HOME/data/gsm8k_llm_judge"
+CKPT_PATH="$HOME/ckpts/llm_judge"
+
+NUM_GPUS=4
+NUM_INFERENCE_ENGINES=4
+TP_SIZE=1
+LOGGER=wandb
+
+# We use a smaller batch size here for demonstration
+uv run --isolated --extra vllm --env-file .env.llm_judge -m examples.llm_as_a_judge.main_llm_judge \
+  data.train_data="['$DATA_DIR/train.parquet']" \
+  data.val_data="['$DATA_DIR/validation.parquet']" \
+  trainer.algorithm.advantage_estimator="grpo" \
+  trainer.policy.model.path="Qwen/Qwen2.5-1.5B-Instruct" \
+  trainer.placement.colocate_all=true \
+  trainer.strategy=fsdp2 \
+  trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
+  trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \
+  generator.num_inference_engines=$NUM_INFERENCE_ENGINES \
+  generator.inference_engine_tensor_parallel_size=$TP_SIZE \
+  trainer.epochs=20 \
+  trainer.eval_batch_size=32 \
+  trainer.eval_before_train=false \
+  trainer.eval_interval=5 \
+  trainer.update_epochs_per_batch=1 \
+  trainer.train_batch_size=32 \
+  trainer.policy_mini_batch_size=32 \
+  trainer.micro_forward_batch_size_per_gpu=40 \
+  trainer.micro_train_batch_size_per_gpu=40 \
+  trainer.ckpt_interval=10 \
+  trainer.max_prompt_length=512 \
+  generator.sampling_params.max_generate_length=1024 \
+  trainer.policy.optimizer_config.lr=1.0e-6 \
+  trainer.algorithm.use_kl_loss=true \
+  generator.backend=vllm \
+  generator.run_engines_locally=true \
+  generator.weight_sync_backend=nccl \
+  generator.async_engine=true \
+  generator.batched=true \
+  generator.n_samples_per_prompt=5 \
+  generator.gpu_memory_utilization=0.8 \
+  trainer.logger="$LOGGER" \
+  trainer.project_name="gsm8k" \
+  trainer.run_name="gsm8k_llm_as_a_judge" \
+  trainer.resume_mode=null \
+  trainer.ckpt_path="$CKPT_PATH" \
+  environment.env_class=llm_as_a_judge \
+  environment.skyrl_gym.llm_as_a_judge.model="gpt-4o-mini" \
+  "$@"
\ No newline at end of file
diff --git a/skyrl-train/skyrl_train/config/skyrl_gym_config/default.yaml b/skyrl-train/skyrl_train/config/skyrl_gym_config/default.yaml
index 37349addfd..a94985f1e5 100644
--- a/skyrl-train/skyrl_train/config/skyrl_gym_config/default.yaml
+++ b/skyrl-train/skyrl_train/config/skyrl_gym_config/default.yaml
@@ -1,10 +1,16 @@
 # @package environment.skyrl_gym
 # number of background workers for env step calls. Set to 0 to disable background workers.
 max_env_workers: 32
+
 text2sql:
   db_path: "/home/ray/default/sql_data"
+
+llm_as_a_judge:
+  model: "gpt-4o-mini"
+  base_url: null # or a local endpoint: http://localhost:8000/v1
+
 search:
   log_requests: false
   search_url: "http://127.0.0.1:8000/retrieve"
   topk: 3
-  timeout: 30
\ No newline at end of file
+  timeout: 30