diff --git a/skyrl-train/tests/gpu/gpu_ci/test_skyrl_gym_generator.py b/skyrl-train/tests/gpu/gpu_ci/test_skyrl_gym_generator.py
index 3c1614c546..d0d0fb7abb 100644
--- a/skyrl-train/tests/gpu/gpu_ci/test_skyrl_gym_generator.py
+++ b/skyrl-train/tests/gpu/gpu_ci/test_skyrl_gym_generator.py
@@ -19,6 +19,9 @@
 from typing import Any, Dict
 import hydra
 from skyrl_train.entrypoints.main_base import config_dir
+from loguru import logger
+
+OBSERVATION_PROMPT = "give me another solution"
 
 
 def get_test_actor_config() -> DictConfig:
@@ -43,7 +46,7 @@ def step(self, action: str):
         self.turns += 1
         done = self.turns >= self.max_turns
         return BaseTextEnvStepOutput(
-            observations=[{"role": "user", "content": f"turn {self.turns}"}] if not done else [],
+            observations=[{"role": "user", "content": f"{OBSERVATION_PROMPT} {self.turns}"}] if not done else [],
             reward=0,
             done=done,
             metadata={},
@@ -290,9 +293,9 @@ async def test_generator_formatting_use_conversation_multi_turn(model_name):
         num_inference_engines=1,
         tensor_parallel_size=1,
         model=model_name,
-        max_prompt_length=1000,
-        max_input_length=3000,
-        max_generate_length=1000,
+        max_prompt_length=3000,
+        max_input_length=10000,
+        max_generate_length=3000,
         env_class="test_env",
         num_prompts=2,
         max_turns=3,
     )
@@ -302,25 +305,34 @@
     for i, resp_ids in enumerate(generator_output["response_ids"]):
         loss_mask = generator_output["loss_masks"][i]
         prompt_token_ids = generator_output["prompt_token_ids"][i]
+        stop_reason = generator_output["stop_reasons"][i]
 
         masked_out_resp_ids = [resp_ids[j] for j in range(len(resp_ids)) if loss_mask[j] == 0]
         masked_in_resp_ids = [resp_ids[j] for j in range(len(resp_ids)) if loss_mask[j] == 1]
         masked_out_resp_str = tokenizer.decode(masked_out_resp_ids)
         masked_in_resp_str = tokenizer.decode(masked_in_resp_ids)
 
-        assert "turn 1" in masked_out_resp_str, "turn 1 observation should be loss masked out"
-        assert "turn 2" in masked_out_resp_str, "turn 2 observation should be loss masked out"
         assert (
             MODEL_TO_GENERATION_PROMPT[model_name] in masked_out_resp_str
             and MODEL_TO_GENERATION_PROMPT[model_name] not in masked_in_resp_str
         ), "generation prompts should be loss masked out"
 
-        # count number of eos tokens in masked_in_resp_ids
-        # NOTE: this could fail if the stop reason is "length" where model fails to generate eos
-        assert (
-            sum(1 for _ in masked_in_resp_ids if _ == tokenizer.eos_token_id) == 3
-        )  # 1 eos for each assistant response
-        assert sum(1 for _ in resp_ids if _ == tokenizer.eos_token_id) == 5  # 2 user eos, 3 assistant eos
+        # Observations and EOS expectations only strictly apply when the model finished turns
+        if stop_reason == "stop":
+            assert (
+                f"{OBSERVATION_PROMPT} 1" in masked_out_resp_str
+            ), f'"{OBSERVATION_PROMPT} 1" observation should be loss masked out'
+            assert (
+                f"{OBSERVATION_PROMPT} 2" in masked_out_resp_str
+            ), f'"{OBSERVATION_PROMPT} 2" observation should be loss masked out'
+            # count number of eos tokens in masked_in_resp_ids: 1 eos per assistant response (3 turns)
+            assert sum(1 for _ in masked_in_resp_ids if _ == tokenizer.eos_token_id) == 3
+            # total eos in full response: 2 user eos + 3 assistant eos
+            assert sum(1 for _ in resp_ids if _ == tokenizer.eos_token_id) == 5
+        else:
+            # On length stops, the model may not produce EOS at the end of each assistant turn.
+            # Only check that generation prompts are masked out.
+            logger.warning(f"Got stop reason {stop_reason}, so we did not fully check the response")
         if model_name == "Qwen/Qwen3-0.6B":
             assert (
                 sum(1 for _ in prompt_token_ids if _ == tokenizer.eos_token_id) == 1
@@ -349,7 +361,7 @@ async def test_generator_formatting_no_use_conversation_multi_turn(model_name):
         num_inference_engines=1,
         tensor_parallel_size=1,
         model=model_name,
-        max_prompt_length=1000,
+        max_prompt_length=3000,
         max_input_length=10000,
         max_generate_length=3000,
         env_class="test_env",
@@ -369,8 +381,12 @@
     for i, resp_ids in enumerate(generator_output["response_ids"]):
         masked_out_resp_str = tokenizer.decode(masked_out_resp_ids)
         masked_in_resp_str = tokenizer.decode(masked_in_resp_ids)
-        assert "turn 1" in masked_out_resp_str, "turn 1 observation should be loss masked out"
-        assert "turn 2" in masked_out_resp_str, "turn 2 observation should be loss masked out"
+        assert (
+            f"{OBSERVATION_PROMPT} 1" in masked_out_resp_str
+        ), f'"{OBSERVATION_PROMPT} 1" observation should be loss masked out'
+        assert (
+            f"{OBSERVATION_PROMPT} 2" in masked_out_resp_str
+        ), f'"{OBSERVATION_PROMPT} 2" observation should be loss masked out'
         assert (
             prompt_str.count(MODEL_TO_GENERATION_PROMPT[model_name])
             + resp_str.count(MODEL_TO_GENERATION_PROMPT[model_name])