skyrl-train/tests/gpu/gpu_ci/test_skyrl_gym_generator.py (46 changes: 31 additions & 15 deletions)
@@ -19,6 +19,9 @@
 from typing import Any, Dict
 import hydra
 from skyrl_train.entrypoints.main_base import config_dir
+from loguru import logger
+
+OBSERVATION_PROMPT = "give me another solution"
 
 
 def get_test_actor_config() -> DictConfig:
@@ -43,7 +46,7 @@ def step(self, action: str):
         self.turns += 1
         done = self.turns >= self.max_turns
         return BaseTextEnvStepOutput(
-            observations=[{"role": "user", "content": f"turn {self.turns}"}] if not done else [],
+            observations=[{"role": "user", "content": f"{OBSERVATION_PROMPT} {self.turns}"}] if not done else [],
             reward=0,
             done=done,
             metadata={},
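
Note on the observation string: switching from the generic "turn {n}" to the distinctive OBSERVATION_PROMPT plausibly hardens the substring assertions later in the test, since a phrase like "turn 1" can also appear in model-generated text and trip the masked-in checks. An illustrative, self-contained sketch of the rollout shape this env produces (class and variable names here are hypothetical, not from the PR):

    OBSERVATION_PROMPT = "give me another solution"

    class _SketchEnv:
        def __init__(self, max_turns: int = 3):
            self.max_turns, self.turns = max_turns, 0

        def step(self, action: str):
            self.turns += 1
            done = self.turns >= self.max_turns
            # each non-final step emits exactly one user observation
            obs = [] if done else [{"role": "user", "content": f"{OBSERVATION_PROMPT} {self.turns}"}]
            return obs, done

    env = _SketchEnv()
    history, done = [], False
    while not done:
        obs, done = env.step("some assistant reply")
        history.extend(obs)
    # history == [{"role": "user", "content": "give me another solution 1"},
    #             {"role": "user", "content": "give me another solution 2"}]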
@@ -290,9 +293,9 @@ async def test_generator_formatting_use_conversation_multi_turn(model_name):
         num_inference_engines=1,
         tensor_parallel_size=1,
         model=model_name,
-        max_prompt_length=1000,
-        max_input_length=3000,
-        max_generate_length=1000,
+        max_prompt_length=3000,
+        max_input_length=10000,
+        max_generate_length=3000,
         env_class="test_env",
         num_prompts=2,
         max_turns=3,
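
The raised budgets track the longer multi-turn rollouts: the engine input on the final turn is roughly the initial prompt plus all earlier generations and observations. A back-of-the-envelope consistency check (the per-turn observation allowance is an assumption, not stated in the PR):

    max_prompt_length = 3000
    max_generate_length = 3000
    max_turns = 3
    per_turn_observation = 50  # rough allowance for the observation text plus chat-template tokens

    worst_case_input = max_prompt_length + (max_turns - 1) * (max_generate_length + per_turn_observation)
    print(worst_case_input)           # 9100
    assert worst_case_input <= 10000  # fits the new max_input_length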
@@ -302,25 +305,34 @@ async def test_generator_formatting_use_conversation_multi_turn(model_name):
     for i, resp_ids in enumerate(generator_output["response_ids"]):
         loss_mask = generator_output["loss_masks"][i]
         prompt_token_ids = generator_output["prompt_token_ids"][i]
+        stop_reason = generator_output["stop_reasons"][i]
         masked_out_resp_ids = [resp_ids[j] for j in range(len(resp_ids)) if loss_mask[j] == 0]
         masked_in_resp_ids = [resp_ids[j] for j in range(len(resp_ids)) if loss_mask[j] == 1]
 
         masked_out_resp_str = tokenizer.decode(masked_out_resp_ids)
         masked_in_resp_str = tokenizer.decode(masked_in_resp_ids)
 
-        assert "turn 1" in masked_out_resp_str, "turn 1 observation should be loss masked out"
-        assert "turn 2" in masked_out_resp_str, "turn 2 observation should be loss masked out"
         assert (
             MODEL_TO_GENERATION_PROMPT[model_name] in masked_out_resp_str
             and MODEL_TO_GENERATION_PROMPT[model_name] not in masked_in_resp_str
         ), "generation prompts should be loss masked out"
 
-        # count number of eos tokens in masked_in_resp_ids
-        # NOTE: this could fail if the stop reason is "length" where model fails to generate eos
-        assert (
-            sum(1 for _ in masked_in_resp_ids if _ == tokenizer.eos_token_id) == 3
-        )  # 1 eos for each assistant response
-        assert sum(1 for _ in resp_ids if _ == tokenizer.eos_token_id) == 5  # 2 user eos, 3 assistant eos
+        # Observations and EOS expectations only strictly apply when the model finished turns
+        if stop_reason == "stop":
+            assert (
+                f"{OBSERVATION_PROMPT} 1" in masked_out_resp_str
+            ), f'"{OBSERVATION_PROMPT} 1" observation should be loss masked out'
+            assert (
+                f"{OBSERVATION_PROMPT} 2" in masked_out_resp_str
+            ), f'"{OBSERVATION_PROMPT} 2" observation should be loss masked out'
+            # count number of eos tokens in masked_in_resp_ids: 1 eos per assistant response (3 turns)
+            assert sum(1 for _ in masked_in_resp_ids if _ == tokenizer.eos_token_id) == 3
+            # total eos in full response: 2 user eos + 3 assistant eos
+            assert sum(1 for _ in resp_ids if _ == tokenizer.eos_token_id) == 5
+        else:
+            # On length stops, the model may not produce EOS at the end of each assistant turn.
+            # Only check that generation prompts are masked out.
+            logger.warning(f"Got stop reason {stop_reason}, so we did not fully check the response")
     if model_name == "Qwen/Qwen3-0.6B":
         assert (
             sum(1 for _ in prompt_token_ids if _ == tokenizer.eos_token_id) == 1
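
The EOS bookkeeping in this hunk partitions response tokens by the loss mask and counts EOS on each side: assistant tokens (mask 1) contribute one EOS per turn, user observations (mask 0) contribute the rest. A minimal, self-contained sketch with made-up token ids (eos_token_id assumed to be 0 for illustration):

    EOS = 0

    # [assistant..., EOS] x 3 turns, with [user obs..., EOS] x 2 in between
    resp_ids  = [7, 8, EOS,  5, EOS,  7, 9, EOS,  6, EOS,  7, EOS]
    loss_mask = [1, 1, 1,    0, 0,    1, 1, 1,    0, 0,    1, 1]

    masked_in  = [t for t, m in zip(resp_ids, loss_mask) if m == 1]
    masked_out = [t for t, m in zip(resp_ids, loss_mask) if m == 0]

    assert sum(t == EOS for t in masked_in) == 3   # one EOS per assistant turn
    assert sum(t == EOS for t in resp_ids) == 5    # plus 2 EOS from user observations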
@@ -349,7 +361,7 @@ async def test_generator_formatting_no_use_conversation_multi_turn(model_name):
         num_inference_engines=1,
         tensor_parallel_size=1,
         model=model_name,
-        max_prompt_length=1000,
+        max_prompt_length=3000,
         max_input_length=10000,
         max_generate_length=3000,
         env_class="test_env",
@@ -369,8 +381,12 @@ async def test_generator_formatting_no_use_conversation_multi_turn(model_name):
         masked_out_resp_str = tokenizer.decode(masked_out_resp_ids)
         masked_in_resp_str = tokenizer.decode(masked_in_resp_ids)
 
-        assert "turn 1" in masked_out_resp_str, "turn 1 observation should be loss masked out"
-        assert "turn 2" in masked_out_resp_str, "turn 2 observation should be loss masked out"
+        assert (
+            f"{OBSERVATION_PROMPT} 1" in masked_out_resp_str
+        ), f'"{OBSERVATION_PROMPT} 1" observation should be loss masked out'
+        assert (
+            f"{OBSERVATION_PROMPT} 2" in masked_out_resp_str
+        ), f'"{OBSERVATION_PROMPT} 2" observation should be loss masked out'
         assert (
             prompt_str.count(MODEL_TO_GENERATION_PROMPT[model_name])
             + resp_str.count(MODEL_TO_GENERATION_PROMPT[model_name])
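
A hedged sketch of the final check's intent: each assistant turn opens with the model's generation prompt, so the combined occurrence count across prompt and response should equal the number of turns. The template strings and the expected total below are assumptions for illustration, not taken from the PR:

    generation_prompt = "<|im_start|>assistant\n"  # Qwen-style marker, assumed
    prompt_str = "<|im_start|>user\nsolve this<|im_end|>\n" + generation_prompt
    resp_str = (
        "answer 1<|im_end|>\n<|im_start|>user\ngive me another solution 1<|im_end|>\n"
        + generation_prompt
        + "answer 2<|im_end|>\n<|im_start|>user\ngive me another solution 2<|im_end|>\n"
        + generation_prompt
        + "answer 3<|im_end|>\n"
    )
    num_turns = 3
    # 1 occurrence in the prompt + 2 in the response == 3 turns
    assert prompt_str.count(generation_prompt) + resp_str.count(generation_prompt) == num_turns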