22 changes: 10 additions & 12 deletions skyrl-gym/skyrl_gym/envs/search/env.py
@@ -56,13 +56,14 @@ def _is_done(self, action: str) -> bool:
return True
return "<answer>" in action and "</answer>" in action

def _postprocess_action(self, action: str) -> str:
if "</search>" in action:
return action.split("</search>")[0] + "</search>"
elif "</answer>" in action:
return action.split("</answer>")[0] + "</answer>"
else:
return action
def _validate_action(self, action: str):
stop_tags = ["</search>", "</answer>"]
for tag in stop_tags:
if tag in action:
assert action.split(tag, 1)[1] == "", (
f"{tag} detected in the response but it is not the last string generated. "
f"Use {stop_tags} as stop strings in the configuration."
)

def _execute_tool(self, tool_group_name: str, tool_name: str, tool_input: Any) -> str:
tool_output = super()._execute_tool(tool_group_name, tool_name, tool_input)
@@ -71,17 +72,15 @@ def _execute_tool(self, tool_group_name: str, tool_name: str, tool_input: Any) -

def step(self, action: str) -> BaseTextEnvStepOutput:
self.turns += 1
action = self._postprocess_action(action)
self._validate_action(action)
self.chat_history.append({"role": "assistant", "content": action})

error = None
done = self._is_done(action)
reward = self._get_reward(action, done)

if done:
return BaseTextEnvStepOutput(
observations=[], reward=reward, done=done, metadata={}, postprocessed_action=action
)
return BaseTextEnvStepOutput(observations=[], reward=reward, done=done, metadata={})

try:
query = self._parse_action(action)
@@ -114,5 +113,4 @@ def step(self, action: str) -> BaseTextEnvStepOutput:
reward=reward,
done=done,
metadata=info,
postprocessed_action=action,
)
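
For reference, here is a minimal standalone sketch (not part of the diff) of how the new validation behaves when the inference engine is configured with the matching stop strings: a response that ends exactly at a stop tag passes, while anything generated after the tag raises.

# Standalone sketch mirroring SearchEnv._validate_action from the diff above.
# Assumes the generator is configured with stop='["</search>", "</answer>"]',
# so a well-formed response ends exactly at one of these tags.
STOP_TAGS = ["</search>", "</answer>"]

def validate_action(action: str) -> None:
    for tag in STOP_TAGS:
        if tag in action:
            # Everything after the first occurrence of the tag must be empty.
            assert action.split(tag, 1)[1] == "", (
                f"{tag} detected in the response but it is not the last string generated. "
                f"Use {STOP_TAGS} as stop strings in the configuration."
            )

validate_action("<search>capital of France</search>")    # OK: the tag is the final text
validate_action("<answer>Paris</answer>")                 # OK
# validate_action("<answer>Paris</answer> extra text")    # would raise AssertionError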
22 changes: 10 additions & 12 deletions skyrl-gym/skyrl_gym/envs/sql/env.py
@@ -89,27 +89,26 @@ def _is_done(self, action: str) -> bool:
return True
return "<solution>" in action and "</solution>" in action

def _postprocess_action(self, action: str) -> str:
if "</sql>" in action:
return action.split("</sql>")[0] + "</sql>"
elif "</solution>" in action:
return action.split("</solution>")[0] + "</solution>"
else:
return action
def _validate_action(self, action: str):
stop_tags = ["</sql>", "</solution>"]
for tag in stop_tags:
if tag in action:
assert action.split(tag, 1)[1] == "", (
f"{tag} detected in the response but it is not the last string generated. "
f"Use {stop_tags} as stop strings in the configuration."
)

def step(self, action: str) -> BaseTextEnvStepOutput:
self.turns += 1
action = self._postprocess_action(action)
self._validate_action(action)
self.chat_history.append({"role": "assistant", "content": action})

error = None
done = self._is_done(action)
reward = self._get_reward(action, done)

if done:
return BaseTextEnvStepOutput(
observations=[], reward=reward, done=done, metadata={}, postprocessed_action=action
)
return BaseTextEnvStepOutput(observations=[], reward=reward, done=done, metadata={})

try:
tool_group_name, tool_name, tool_input = self._parse_action(action)
@@ -140,5 +139,4 @@ def step(self, action: str) -> BaseTextEnvStepOutput:
reward=reward,
done=done,
metadata=info,
postprocessed_action=action,
)
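
For comparison, a minimal sketch (not part of the diff) of the behavioral change: the removed `_postprocess_action` silently truncated anything generated after a closing tag, while the new `_validate_action` relies on the stop strings having ended generation there and asserts otherwise.

# Sketch of the behavior change for the SQL tags (mirrors the removed and added code above).
response = "<sql>SELECT 1;</sql> stray text the model kept generating"

# Old behavior: truncate at the closing tag and drop the rest.
old_action = response.split("</sql>")[0] + "</sql>"   # -> "<sql>SELECT 1;</sql>"

# New behavior: rely on stop='["</sql>", "</solution>"]' so the response already ends at the
# tag; if it does not, the environment raises to surface the misconfiguration.
def validate_action(action: str) -> None:
    for tag in ["</sql>", "</solution>"]:
        if tag in action:
            assert action.split(tag, 1)[1] == "", f"{tag} is not the last string generated."

validate_action("<sql>SELECT 1;</sql>")   # OK
# validate_action(response)               # would raise AssertionError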
32 changes: 0 additions & 32 deletions skyrl-gym/tests/test_search.py
@@ -132,36 +132,6 @@ def test_parse_action(search_env, action, expected_input):
assert query == expected_input


# =============================================================================
# ACTION POSTPROCESSING FUNCTIONALITY TESTS
# =============================================================================


@pytest.mark.parametrize(
"action, expected",
[
# Search with extra content after closing tag
("<search>Query</search> extra content", "<search>Query</search>"),
# Answer with extra content after closing tag
("<answer>Answer</answer> extra content", "<answer>Answer</answer>"),
# Both search and answer tags
("<search>Query</search> <answer>Answer</answer> extra", "<search>Query</search>"),
# Only search tag (no extra content)
("<search>Query</search>", "<search>Query</search>"),
# Only answer tag (no extra content)
("<answer>Answer</answer>", "<answer>Answer</answer>"),
# No special tags
("Just plain text", "Just plain text"),
# end tag before start tag
("</search><search>Query</search>", "</search>"),
],
)
def test_postprocess_action(search_env, action, expected):
"""Test action postprocessing."""
result = search_env._postprocess_action(action)
assert result == expected


# =============================================================================
# EPISODE TERMINATION CONDITIONS TESTS
# =============================================================================
@@ -313,8 +283,6 @@ def test_invalid_search_parsing(search_env, mock_search_api):
("<answer>emmanuel macron</answer>", {"target": "Emmanuel Macron"}, 1.0, True),
# Answer without articles
("<answer>Emmanuel Macron</answer>", {"target": "The Emmanuel Macron"}, 1.0, True),
# Multiple answer tags (should use first one)
("<answer>Wrong</answer> <answer>Emmanuel Macron</answer>", {"target": "Emmanuel Macron"}, 0.0, True),
# No answer tag
("Just text without answer tag", {"target": "Emmanuel Macron"}, 0.0, False),
],
5 changes: 5 additions & 0 deletions skyrl-train/docs/examples/multi_turn_text2sql.rst
@@ -128,6 +128,8 @@ Now that we have our dataset and database files, let's walk through the some of
#### generation sampling params (relevant to algorithm correctness)
generator.sampling_params.temperature=0.6 \
generator.sampling_params.top_p=0.95 \
generator.sampling_params.stop='["</sql>", "</solution>"]' \
generator.eval_sampling_params.stop='["</sql>", "</solution>"]' \

#### training configuration
trainer.policy.optimizer_config.lr=1.0e-6 \
@@ -146,6 +148,9 @@ Now that we have our dataset and database files, let's walk through the some of
- Chat templating and loss masking for multi-turn conversations are handled by the ``SkyRLGymGenerator`` class.

- In the above example, we set ``use_conversation_multi_turn=false`` to enforce that the multi-turn conversation is formatted as a single assistant response.
- We also set ``stop='["</sql>", "</solution>"]'`` for both ``sampling_params`` and ``eval_sampling_params`` as a part of the training recipe.

Review thread on this line:

Member: It would be best for our Generator to handle this case, and to link that as an example in the doc.
Currently, both the search and the sql scripts use use_conversation_multi_turn=False, but technically after #123, these should be able to do use_conversation_multi_turn=True with limited effect on performance.
Ideally, I'd like SkyGymGenerator to handle this case in this PR itself. Could you modify SkyGymGenerator to support use_conversation_multi_turn and custom stop tokens?

Collaborator Author: Yes makes sense. I'll add a flag to toggle it and add a multi-turn search-R1 script.

Collaborator Author: Should be addressed with flag append_eos_token_after_stop_str_in_multi_turn. Updated PR description and a new unit test. Could you take another look when you get a chance? Thanks!

Collaborator Author: Nvm, just realized sampling_params is None in the generator level. Will further fix.

Collaborator Author: @SumanthRH ready again!
- If you are using ``generator.use_conversation_multi_turn=true``, you might want to append an EOS token ID to the end of the response after these stop strings to adhere to the model's behavior (i.e. ending generation with an EOS token ID rather than say ``</solution>``). This can be done by setting ``generator.append_eos_token_after_stop_str_in_multi_turn=true`` in the generator config; a short sketch of this behavior follows this list. The full script is available in `examples/text_to_sql/run_skyrl_sql_conversation_format.sh`.
- If you want to use a conversation-based format, you can set ``use_conversation_multi_turn=true`` and the model will generate a separate assistant response for each turn. This is supported only with ``backend="vllm"`` as of now.
- See :code_link:`skyrl_train/generators/skyrl_gym_generator.py` for more details on both options!
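
As a minimal illustration of the assumed effect of ``append_eos_token_after_stop_str_in_multi_turn`` (the helper name and integration point below are hypothetical, not the actual ``SkyRLGymGenerator`` implementation): a turn that stopped on one of the configured stop strings gets the EOS token id appended, so the stored response still ends with EOS as the model expects.

# Illustrative sketch; the helper name and integration point are assumptions,
# not the actual SkyRLGymGenerator code.
STOP_STRINGS = ["</sql>", "</solution>"]

def maybe_append_eos(response_ids: list[int], response_text: str, eos_token_id: int) -> list[int]:
    # Only turns that ended on a stop string (rather than on EOS) need the extra token.
    if any(response_text.rstrip().endswith(s) for s in STOP_STRINGS):
        return response_ids + [eos_token_id]
    return response_ids

# e.g. a turn ending in "...</sql>" gets eos_token_id appended before being stored in the
# conversation; a final turn that already ended with EOS is left unchanged.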

12 changes: 12 additions & 0 deletions skyrl-train/docs/examples/search.rst
@@ -100,6 +100,7 @@ Let's walk through configuration for running GRPO to train a 4-turn search agent
generator.use_conversation_multi_turn=false \
generator.sampling_params.temperature=1.0 \
generator.sampling_params.top_p=1.0 \
generator.sampling_params.stop='["</search>", "</answer>"]' \

# - Environment: environment class, max env workers, search env settings
environment.env_class="search" \
@@ -112,12 +113,23 @@ Let's walk through configuration for running GRPO to train a 4-turn search agent
trainer.eval_batch_size=256 \
trainer.eval_before_train=false \
generator.eval_sampling_params.temperature=0 \
generator.eval_sampling_params.stop='["</search>", "</answer>"]' \
trainer.eval_interval=50 \
... # logging + checkpointing configuration (see `examples/search/run_search.sh` for the full script)

To change the number of turns, you can simply change the ``generator.max_turns`` setting.
For more details on environment implementation, see :skyrl_gym_link:`skyrl_gym/envs/search/env.py`.

Note we add ``stop='["</search>", "</answer>"]'`` for both generation and evaluation sampling parameters
to adhere to the Search-R1 recipe.
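
As an illustration of how these stop strings reach the inference engine, here is a minimal sketch assuming the vLLM backend (the generator handles this plumbing from the config above, so this is illustrative only):

from vllm import SamplingParams

# The stop strings configured above map onto SamplingParams.stop, so generation halts
# right at </search> or </answer> and the environment's _validate_action check passes.
sampling_params = SamplingParams(
    temperature=1.0,
    top_p=1.0,
    max_tokens=500,
    stop=["</search>", "</answer>"],
    include_stop_str_in_output=True,  # keep the closing tag so the env can parse the query/answer
)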

If you are using ``generator.use_conversation_multi_turn=true``,
you might want to append an EOS token ID to the end of the response after these stop strings to adhere
to the model's behavior (i.e. ending generation with an EOS token ID rather than say ``</answer>``).
This can be done by setting ``generator.append_eos_token_after_stop_str_in_multi_turn=true`` in the generator config.
The full script is available in `examples/search/run_search_conversation_format.sh`.


Launching Your Training Run
---------------------------

3 changes: 2 additions & 1 deletion skyrl-train/examples/search/run_search.sh
@@ -45,9 +45,9 @@ uv run --isolated --frozen --extra vllm -m skyrl_train.entrypoints.main_base \
generator.use_conversation_multi_turn=false \
generator.n_samples_per_prompt=5 \
generator.max_turns=4 \
generator.use_conversation_multi_turn=false \
generator.sampling_params.temperature=1.0 \
generator.sampling_params.top_p=1.0 \
generator.sampling_params.stop='["</search>", "</answer>"]' \
environment.env_class="search" \
environment.skyrl_gym.max_env_workers=16 \
environment.skyrl_gym.search.log_requests=false \
@@ -64,6 +64,7 @@ uv run --isolated --frozen --extra vllm -m skyrl_train.entrypoints.main_base \
trainer.eval_batch_size=256 \
trainer.eval_before_train=false \
generator.eval_sampling_params.temperature=0 \
generator.eval_sampling_params.stop='["</search>", "</answer>"]' \
trainer.export_path="$HOME/skyrl-search_4turns_maxgeneratelen_500/exports" \
trainer.eval_interval=50 \
$@
79 changes: 79 additions & 0 deletions skyrl-train/examples/search/run_search_conversation_format.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
set -x

# The exact same script as `run_search.sh` but with `use_conversation_multi_turn=true`
# and hence `append_eos_token_after_stop_str_in_multi_turn=true`
# See https://skyrl.readthedocs.io/en/latest/tutorials/skyrl_gym_generator.html on the
# difference between the two options. You might want to change the data generation prompt
# to let the model know that we are doing multi-turn conversations (i.e. user will provide
# the search result for each turn).

# Colocated GRPO training+generation for Qwen2.5-3B-Instruct on SearchR1 data.
# follow the instructions in examples/search/README.md for setting up the dataset
# and for starting the local search server
# export WANDB_API_KEY=<your_key_here>
# bash examples/search/run_search_conversation_format.sh

# path for dataset (.parquet files) containing the prompts and metadata for each question
DATA_DIR="$HOME/data/searchR1"

uv run --isolated --frozen --extra vllm -m skyrl_train.entrypoints.main_base \
data.train_data="['${DATA_DIR}/train.parquet']" \
data.val_data="['${DATA_DIR}/validation.parquet']" \
trainer.algorithm.advantage_estimator="grpo" \
trainer.policy.optimizer_config.lr=1.0e-6 \
trainer.policy.optimizer_config.max_grad_norm=0.5 \
trainer.policy.optimizer_config.num_warmup_steps=94 \
trainer.algorithm.use_kl_loss=true \
trainer.algorithm.kl_loss_coef=0.001 \
trainer.policy.model.path="Qwen/Qwen2.5-3B-Instruct" \
trainer.placement.colocate_all=true \
trainer.strategy=fsdp2 \
trainer.policy.fsdp_config.cpu_offload=false \
trainer.ref.fsdp_config.cpu_offload=true \
trainer.placement.policy_num_gpus_per_node=8 \
trainer.placement.ref_num_gpus_per_node=8 \
generator.num_inference_engines=4 \
generator.inference_engine_tensor_parallel_size=2 \
generator.backend=vllm \
generator.run_engines_locally=true \
generator.weight_sync_backend=nccl \
generator.gpu_memory_utilization=0.5 \
trainer.epochs=1 \
trainer.update_epochs_per_batch=1 \
trainer.train_batch_size=512 \
trainer.policy_mini_batch_size=256 \
trainer.micro_forward_batch_size_per_gpu=4 \
trainer.micro_train_batch_size_per_gpu=4 \
trainer.max_prompt_length=2048 \
generator.max_input_length=4096 \
generator.sampling_params.max_generate_length=500 \
generator.async_engine=true \
generator.batched=false \
generator.use_conversation_multi_turn=true \
generator.n_samples_per_prompt=5 \
generator.max_turns=4 \
generator.sampling_params.temperature=1.0 \
generator.sampling_params.top_p=1.0 \
generator.sampling_params.stop='["</search>", "</answer>"]' \
generator.append_eos_token_after_stop_str_in_multi_turn=true \
environment.env_class="search" \
environment.skyrl_gym.max_env_workers=16 \
environment.skyrl_gym.search.log_requests=false \
environment.skyrl_gym.search.search_url="http://127.0.0.1:8000/retrieve" \
environment.skyrl_gym.search.topk=3 \
trainer.logger="wandb" \
trainer.project_name="skyrl-search" \
trainer.run_name="skyrl-search_4turns_maxgeneratelen_500" \
trainer.ckpt_interval=20 \
trainer.hf_save_interval=100 \
trainer.max_ckpts_to_keep=5 \
trainer.resume_mode=latest \
trainer.ckpt_path="$HOME/skyrl-search_4turns_maxgeneratelen_500" \
trainer.eval_batch_size=256 \
trainer.eval_before_train=false \
generator.eval_sampling_params.temperature=0 \
generator.eval_sampling_params.stop='["</search>", "</answer>"]' \
trainer.export_path="$HOME/skyrl-search_4turns_maxgeneratelen_500/exports" \
trainer.eval_interval=50 \
$@

3 changes: 2 additions & 1 deletion skyrl-train/examples/text_to_sql/run_skyrl_sql.sh
@@ -56,9 +56,10 @@ uv run --isolated --extra vllm -m skyrl_train.entrypoints.main_base \
generator.n_samples_per_prompt=5 \
generator.gpu_memory_utilization=0.7 \
generator.max_turns=6 \
generator.use_conversation_multi_turn=false \
generator.sampling_params.temperature=0.6 \
generator.sampling_params.top_p=0.95 \
generator.sampling_params.stop='["</sql>", "</solution>"]' \
generator.eval_sampling_params.stop='["</sql>", "</solution>"]' \
environment.skyrl_gym.text2sql.db_path=$DB_PATH \
trainer.logger="wandb" \
trainer.project_name="skyrlsql" \
80 changes: 80 additions & 0 deletions skyrl-train/examples/text_to_sql/run_skyrl_sql_conversation_format.sh
@@ -0,0 +1,80 @@
set -x

# The exact same script as `run_skyrl_sql.sh` but with `use_conversation_multi_turn=true`
# and hence `append_eos_token_after_stop_str_in_multi_turn=true`
# See https://skyrl.readthedocs.io/en/latest/tutorials/skyrl_gym_generator.html for the behavior
# that use_conversation_multi_turn corresponds to. You might want to change the data generation prompt
# to let the model know that we are doing multi-turn conversations (i.e. the user will provide
# the SQL execution result for each turn).

# Colocated GRPO training+generation for Qwen2.5-Coder-7B-Instruct on SkyRL-SQL-653 data.
# Uses 1 node with 8 GPUs.
# huggingface-cli download NovaSky-AI/SkyRL-SQL-653-data-newfmt --local-dir $HOME/data/sql --repo-type dataset
# export WANDB_API_KEY=<your_key_here>
# bash examples/text_to_sql/run_skyrl_sql_conversation_format.sh

# change these paths to your own
DATA_DIR="$HOME/data/sql"
DB_PATH="$HOME/data/sql/db_files/data"
CKPT_PATH="$HOME/ckpts/skyrl_sql_7B_ckpt"

NUM_GPUS=8
NUM_INFERENCE_ENGINES=2
TP_SIZE=4
MAX_INPUT_LENGTH=29000
MAX_GENERATE_LENGTH=3000
TRAIN_BATCH_SIZE=256

uv run --isolated --extra vllm -m skyrl_train.entrypoints.main_base \
trainer.algorithm.advantage_estimator="grpo" \
data.train_data="['$DATA_DIR/train.parquet']" \
data.val_data="['$DATA_DIR/validation.parquet']" \
trainer.policy.model.path="Qwen/Qwen2.5-Coder-7B-Instruct" \
trainer.epochs=30 \
trainer.placement.colocate_all=true \
trainer.strategy=fsdp2 \
trainer.policy.fsdp_config.cpu_offload=false \
trainer.ref.fsdp_config.cpu_offload=true \
trainer.policy.optimizer_config.max_grad_norm=0.5 \
trainer.policy.sequence_parallel_size=1 \
trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \
generator.num_inference_engines=$NUM_INFERENCE_ENGINES \
generator.inference_engine_tensor_parallel_size=$TP_SIZE \
trainer.train_batch_size=$TRAIN_BATCH_SIZE \
trainer.micro_forward_batch_size_per_gpu=8 \
trainer.micro_train_batch_size_per_gpu=1 \
trainer.max_prompt_length=6000 \
generator.max_input_length=$MAX_INPUT_LENGTH \
generator.sampling_params.max_generate_length=$MAX_GENERATE_LENGTH \
trainer.policy.optimizer_config.lr=1.0e-6 \
trainer.policy_mini_batch_size=256 \
trainer.algorithm.use_kl_loss=false \
trainer.ckpt_interval=60 \
trainer.hf_save_interval=30 \
trainer.dump_data_batch=true \
generator.backend=vllm \
generator.run_engines_locally=true \
generator.weight_sync_backend=nccl \
generator.async_engine=true \
generator.batched=false \
environment.env_class=text2sql \
generator.use_conversation_multi_turn=true \
generator.n_samples_per_prompt=5 \
generator.gpu_memory_utilization=0.7 \
generator.max_turns=6 \
generator.sampling_params.temperature=0.6 \
generator.sampling_params.top_p=0.95 \
generator.sampling_params.stop='["</sql>", "</solution>"]' \
generator.append_eos_token_after_stop_str_in_multi_turn=true \
generator.eval_sampling_params.stop='["</sql>", "</solution>"]' \
environment.skyrl_gym.text2sql.db_path=$DB_PATH \
trainer.logger="wandb" \
trainer.project_name="skyrlsql" \
trainer.run_name="skyrlsql_repro" \
trainer.resume_mode=latest \
trainer.ckpt_path=$CKPT_PATH \
trainer.eval_batch_size=1024 \
trainer.eval_before_train=true \
trainer.eval_interval=5 \
$@
3 changes: 2 additions & 1 deletion skyrl-train/examples/text_to_sql/run_sql_deepspeed.sh
@@ -46,9 +46,10 @@ uv run --isolated --frozen --extra vllm --extra deepspeed -m skyrl_train.entrypo
generator.n_samples_per_prompt=5 \
generator.gpu_memory_utilization=0.7 \
generator.max_turns=5 \
generator.use_conversation_multi_turn=false \
generator.sampling_params.temperature=0.6 \
generator.sampling_params.top_p=0.95 \
generator.sampling_params.stop='["</sql>", "</solution>"]' \
generator.eval_sampling_params.stop='["</sql>", "</solution>"]' \
environment.skyrl_gym.text2sql.db_path=$DB_PATH \
trainer.logger="wandb" \
trainer.project_name="skyrlsql" \