diff --git a/skyrl-train/tests/cpu/generators/test_skyrl_gym_generator.py b/skyrl-train/tests/cpu/generators/test_skyrl_gym_generator.py
index b82738c88f..735bae28f5 100644
--- a/skyrl-train/tests/cpu/generators/test_skyrl_gym_generator.py
+++ b/skyrl-train/tests/cpu/generators/test_skyrl_gym_generator.py
@@ -5,71 +5,47 @@
 import pytest
 from typing import List, Dict, Any
 from unittest.mock import AsyncMock, MagicMock, patch
+from transformers import AutoTokenizer
 
 from skyrl_train.generators.skyrl_gym_generator import SkyRLGymGenerator
 from skyrl_train.generators.base import GeneratorInput, GeneratorOutput, ConversationType
 from skyrl_train.generators.utils import concatenate_generator_outputs, get_metrics_from_generator_output
 from skyrl_gym.envs.base_text_env import BaseTextEnvStepOutput
 
-# Mock constants, where 4 is the eos token id
-MOCK_LLM_OUTPUT_IDS = [1, 10, 12, 4]
-MOCK_TOKENIZER_ENCODED_IDS = [1, 2, 3, 4]
+# Mock constants
+MOCK_LLM_OUTPUT_TEXT = "mocked output"
 
 
-# TODO (erictang000): clean up the mocking for tests in this file
-@pytest.fixture
-def mock_tokenizer():
-    """
-    A mock tokenizer that encodes any non-empty string to `MOCK_TOKENIZER_ENCODED_IDS`.
-    For chat template, if `tokenize=False`, concatenate the content of each message.
-    If `tokenize=True`, return `MOCK_TOKENIZER_ENCODED_IDS` for each message.
-    """
-    tokenizer = MagicMock()
-
-    def mock_apply_chat_template(x, **kwargs):
-        if not kwargs.get("tokenize", True):
-            return "".join([str(i["content"]) for i in x])
-        else:
-            # Non-dict return
-            if isinstance(x, list) and len(x) > 0 and isinstance(x[0], list):
-                # Multiple prompts
-                return [MOCK_TOKENIZER_ENCODED_IDS.copy() for _ in x]
-            else:
-                # Single prompt or conversation
-                return MOCK_TOKENIZER_ENCODED_IDS.copy()
-
-    def mock_encode(x, **kwargs):
-        if x != "":
-            return MOCK_TOKENIZER_ENCODED_IDS.copy()
-        else:
-            return []
-
-    tokenizer.apply_chat_template.side_effect = mock_apply_chat_template
-    tokenizer.decode.side_effect = lambda x: "decoded_output"
-    tokenizer.encode.side_effect = mock_encode
-    tokenizer.eos_token_id = 4
-    tokenizer.eos_token = "<|end_of_turn|>"
-    tokenizer.return_value = {"input_ids": MOCK_TOKENIZER_ENCODED_IDS.copy()}  # simulate tokenized response
-    return tokenizer
+@pytest.fixture(params=[
+    "Qwen/Qwen2.5-0.5B-Instruct",
+    "unsloth/Llama-3.2-1B-Instruct",
+    "Qwen/Qwen3-0.6B",
+])
+def model_name(request):
+    return request.param
+
+
+@pytest.fixture
+def tokenizer(model_name):
+    return AutoTokenizer.from_pretrained(model_name)
 
 
 @pytest.fixture
-def mock_llm():
+def mock_llm(tokenizer):
     """
-    This replaces InferenceEngineClient, where `.generate()` always returns MOCK_LLM_OUTPUT_IDS
-    for each prompt, with corresponding string output "mocked output".
+    Mock `InferenceEngineClient.generate()` using the real tokenizer.
""" mock = MagicMock() - # Mock the new generate method def mock_generate(input_batch): num_prompts = len(input_batch["prompts"]) if "prompts" in input_batch else len(input_batch["prompt_token_ids"]) + text = MOCK_LLM_OUTPUT_TEXT + (tokenizer.eos_token or "") + token_ids = tokenizer.encode(text, add_special_tokens=False) return { - "responses": ["mocked output"] * num_prompts, + "responses": [MOCK_LLM_OUTPUT_TEXT] * num_prompts, "stop_reasons": ["stop"] * num_prompts, - # say response gets tokenized to 3 tokens - "response_logprobs": [[0.1] * len(MOCK_LLM_OUTPUT_IDS)] * num_prompts, - "response_ids": [MOCK_LLM_OUTPUT_IDS.copy()] * num_prompts, + "response_logprobs": [[0.1] * len(token_ids)] * num_prompts, + "response_ids": [token_ids.copy()] * num_prompts, } mock.generate = AsyncMock(side_effect=mock_generate) @@ -225,7 +201,7 @@ def validate_generator_output(output: GeneratorOutput) -> bool: @patch("skyrl_gym.make") @pytest.mark.parametrize("use_conversation_multi_turn", [True, False]) async def test_agent_loop_single_turn( - mock_make, mock_tokenizer, mock_llm, mock_env, mock_generator_cfg, use_conversation_multi_turn, mock_env_cfg + mock_make, tokenizer, mock_llm, mock_env, mock_generator_cfg, use_conversation_multi_turn, mock_env_cfg, model_name ): """ This test mocks when we call SkyRLGymGenerator.agent_loop() despite being a single-turn generation. @@ -233,8 +209,6 @@ async def test_agent_loop_single_turn( """ mock_generator_cfg.use_conversation_multi_turn = use_conversation_multi_turn mock_env.step.side_effect = lambda x: BaseTextEnvStepOutput(observations=[], reward=1.0, done=True, metadata={}) - mock_tokenizer.eos_token_id = 4 # bypass check for eos token id for this test - mock_make.return_value = mock_env mock_env.init.return_value = ([{"role": "user", "content": "Initial input"}], {}) @@ -242,8 +216,8 @@ async def test_agent_loop_single_turn( generator_cfg=mock_generator_cfg, skyrl_gym_cfg=mock_env_cfg, inference_engine_client=mock_llm, - tokenizer=mock_tokenizer, - model_name="test_model", + tokenizer=tokenizer, + model_name=model_name, ) generator.base_conversation_token_ids = [] # to make sure observation_ids are encoded correctly @@ -253,15 +227,20 @@ async def test_agent_loop_single_turn( prompt, mock_env_cfg.env_class, extras, max_tokens=8, max_input_length=512 ) - assert response_ids == MOCK_LLM_OUTPUT_IDS + assert isinstance(response_ids, list) and len(response_ids) > 0 assert reward == 1.0 assert stop_reason == "stop" - assert loss_mask == [1] * len(MOCK_LLM_OUTPUT_IDS) + if ("Qwen3" in model_name) and use_conversation_multi_turn: + # With Qwen3 retokenization, assistant masks include zeros for generation prompts + assert len(loss_mask) == len(response_ids) + assert sum(loss_mask) >= 1 + else: + assert loss_mask == [1] * len(response_ids) @pytest.mark.asyncio @patch("skyrl_gym.make") -async def test_generate_batched(mock_make, mock_tokenizer, mock_llm, mock_env, mock_generator_cfg, mock_env_cfg): +async def test_generate_batched(mock_make, tokenizer, mock_llm, mock_env, mock_generator_cfg, mock_env_cfg, model_name): mock_make.return_value = mock_env mock_env.init.return_value = ([{"role": "user", "content": "Initial input"}], {}) @@ -269,8 +248,8 @@ async def test_generate_batched(mock_make, mock_tokenizer, mock_llm, mock_env, m generator_cfg=mock_generator_cfg, skyrl_gym_cfg=mock_env_cfg, inference_engine_client=mock_llm, - tokenizer=mock_tokenizer, - model_name="test_model", + tokenizer=tokenizer, + model_name=model_name, ) generator.base_conversation_token_ids = [] 
@@ -286,11 +265,11 @@ async def test_generate_batched(mock_make, mock_tokenizer, mock_llm, mock_env, m
     generator_output: GeneratorOutput = await generator.generate(input_batch)
 
     # uses output from llm directly
-    assert generator_output["response_ids"][0] == MOCK_LLM_OUTPUT_IDS
+    assert isinstance(generator_output["response_ids"][0], list) and len(generator_output["response_ids"][0]) > 0
     assert generator_output["rewards"][0] == 1.0
     assert generator_output["stop_reasons"][0] == "stop"
-    assert generator_output["loss_masks"][0] == [1] * len(MOCK_LLM_OUTPUT_IDS)
+    assert generator_output["loss_masks"][0] == [1] * len(generator_output["response_ids"][0])
 
 
 def test_generator_output_concatenation():
@@ -358,7 +337,7 @@ def test_get_metrics_from_generator_output():
 @pytest.mark.parametrize("batched", [True, False])
 @patch("skyrl_gym.make")
 async def test_generate_interface_compliance(
-    mock_make, mock_tokenizer, mock_llm, mock_env, mock_generator_cfg, mock_env_cfg, batched
+    mock_make, tokenizer, mock_llm, mock_env, mock_generator_cfg, mock_env_cfg, batched, model_name
 ):
     """Test that SkyRLGymGenerator.generate() strictly conforms to the TypedDict interface.
@@ -373,8 +352,8 @@ async def test_generate_interface_compliance(
     generator = SkyRLGymGenerator(
         generator_cfg=mock_generator_cfg,
         skyrl_gym_cfg=mock_env_cfg,
         inference_engine_client=mock_llm,
-        tokenizer=mock_tokenizer,
-        model_name="test_model",
+        tokenizer=tokenizer,
+        model_name=model_name,
     )
     generator.base_conversation_token_ids = []  # to make sure observation_ids are encoded correctly
@@ -438,7 +417,7 @@ async def test_generate_interface_compliance(
 @pytest.mark.parametrize("turns_to_exceed", [1, 3])  # Test single-turn and multi-turn scenarios
 @patch("skyrl_gym.make")
 async def test_length_limit_exceeded_during_conversation(
-    mock_make, mock_tokenizer, mock_llm, mock_env, mock_generator_cfg, mock_env_cfg, turns_to_exceed
+    mock_make, tokenizer, mock_llm, mock_env, mock_generator_cfg, mock_env_cfg, turns_to_exceed
 ):
     """Test that length limit is enforced during multi-turn conversations.
@@ -460,48 +439,43 @@ def mock_step_never_done(output):
             metadata={},
         )
 
-    # We start with initial prompt len 4 due to mock_apply_chat_template
-    # Each turn, observation is 4 tokens due to mock_encode
     mock_env.step.side_effect = mock_step_never_done
-    max_input_length = 20  # Low limit to trigger length exceeded
-
-    # Mock the new generate method
-    def mock_generate(input_batch):
-        num_prompts = len(input_batch["prompts"]) if "prompts" in input_batch else len(input_batch["prompt_token_ids"])
-        if turns_to_exceed == 1:
-            mock_llm_output_ids = [1] * 20  # Enough to exceed limit immediately (4 + 20 + 4 = 28 > 20)
-            assert (
-                len(MOCK_TOKENIZER_ENCODED_IDS) + len(mock_llm_output_ids) + len(MOCK_TOKENIZER_ENCODED_IDS)
-                > max_input_length
-            )
-        else:
-            assert turns_to_exceed == 3
-            mock_llm_output_ids = [1] * 2  # Enough to exceed limit after 3 turns (4 + (2 + 4) * 3 = 22 > 20)
-            assert (
-                len(MOCK_TOKENIZER_ENCODED_IDS)
-                + (len(mock_llm_output_ids) + len(MOCK_TOKENIZER_ENCODED_IDS)) * turns_to_exceed
-                > max_input_length
-            )
-        return {
-            "responses": ["mocked output"] * num_prompts,
-            "stop_reasons": ["stop"] * num_prompts,
-            "response_logprobs": [[0.1] * len(mock_llm_output_ids)] * num_prompts,
-            "response_ids": [mock_llm_output_ids.copy()] * num_prompts,
-        }
-
-    mock_llm.generate = AsyncMock(side_effect=mock_generate)
 
     generator = SkyRLGymGenerator(
         generator_cfg=mock_generator_cfg,
         skyrl_gym_cfg=mock_env_cfg,
         inference_engine_client=mock_llm,
-        tokenizer=mock_tokenizer,
+        tokenizer=tokenizer,
         model_name="test_model",
     )
-    generator.base_conversation_token_ids = []  # to make sure observation_ids are encoded correctly
-
+    # compute lengths using the real tokenizer
     prompt = [{"role": "user", "content": "Start conversation"}]
     extras = {"test": "value"}
+    initial_prompt_len = len(tokenizer.apply_chat_template(prompt, add_generation_prompt=True, tokenize=True))
+    # observation token length per turn
+    obs_ids = tokenizer.apply_chat_template(
+        [*generator.base_conversation, {"role": "user", "content": "next"}],
+        add_generation_prompt=True,
+        tokenize=True,
+    )[len(generator.base_conversation_token_ids):]
+    observation_len = len(obs_ids)
+    per_turn_out_len = 1
+
+    # choose a max_input_length that will be exceeded exactly after `turns_to_exceed` turns
+    max_input_length = initial_prompt_len + turns_to_exceed * (per_turn_out_len + observation_len) - 1
+
+    # Mock generate() to output a fixed number of tokens per turn
+    def mock_generate_len(input_batch):
+        num_prompts = len(input_batch["prompts"]) if "prompts" in input_batch else len(input_batch["prompt_token_ids"])
+        out_ids = [1] * per_turn_out_len
+        return {
+            "responses": [MOCK_LLM_OUTPUT_TEXT] * num_prompts,
+            "stop_reasons": ["stop"] * num_prompts,
+            "response_logprobs": [[0.1] * len(out_ids)] * num_prompts,
+            "response_ids": [out_ids.copy()] * num_prompts,
+        }
+
+    mock_llm.generate = AsyncMock(side_effect=mock_generate_len)
 
     response_ids, reward, stop_reason, loss_mask, prompt_token_ids, rollout_logprobs = await generator.agent_loop(
         prompt, "test_env", extras, max_tokens=100, max_input_length=max_input_length
@@ -525,95 +499,73 @@ def mock_generate(input_batch):
 @pytest.mark.asyncio
 @patch("skyrl_gym.make")
 async def test_multi_turn_response_truncation(
-    mock_make, mock_tokenizer, mock_llm, mock_env, mock_generator_cfg, mock_env_cfg
+    mock_make, tokenizer, mock_llm, mock_env, mock_generator_cfg, mock_env_cfg
 ):
-    """
-    Tests that in a multi-turn conversation, if the final tokenized response exceeds the
-    calculated maximum length, it is correctly
-    truncated and the stop reason is set to 'length'.
-    """
+    """Ensure a multi-turn conversation is truncated and stop_reason is set to 'length'."""
     mock_make.return_value = mock_env
-    mock_generator_cfg.max_turns = 3  # Ensure multi-turn logic is triggered
-    mock_generator_cfg.batched = False  # Test is for agent_loop
+    mock_generator_cfg.max_turns = 3
+    mock_generator_cfg.batched = False
     mock_generator_cfg.use_conversation_multi_turn = True
     mock_env.init.return_value = ([{"role": "user", "content": "Initial input"}], {})
 
-    # Configure environment to run for multiple turns to generate enough tokens for truncation
     step_count = 0
 
-    def mock_step_multi_turn(output):
+    def mock_step_multi_turn(_):
         nonlocal step_count
         step_count += 1
-        done = step_count >= 10  # Allow many turns to exceed length limit
+        done = step_count >= 10
         return BaseTextEnvStepOutput(
             observations=[{"role": "user", "content": "next turn"}], reward=0.5, done=done, metadata={}
         )
 
     mock_env.step.side_effect = mock_step_multi_turn
 
-    # Define token lengths to control the test
-    initial_prompt_len = 13
-    max_tokens_from_llm = 20
-    max_input_len = 50
-
-    # Expected max response tokens = max_tokens + max_input_length - initial_prompt_length
-    expected_max_response_tokens = max_tokens_from_llm + max_input_len - initial_prompt_len  # 20 + 50 - 13 = 57
-
-    def mock_apply_chat_template(messages, **kwargs):
-        if kwargs.get("tokenize", True):
-            # Return initial prompt tokens
-            return [1] * initial_prompt_len
-        else:
-            # Not used in messages_mode=False
-            return "".join([msg.get("content", "") for msg in messages])
-
-    def mock_encode(text, **kwargs):
-        # This makes observation_ids to always be 13 tokens
-        return [1] * 13
-
-    mock_tokenizer.apply_chat_template.side_effect = mock_apply_chat_template
-    mock_tokenizer.encode.side_effect = mock_encode
-
-    # The intitial prompt is 13 tokens due to mock_apply_chat_template
-    # Each turn, observation is 13 tokens due to mock_encode and empty system_prompt_ids
-    # And the LLM response is 4 tokens due to MOCK_LLM_OUTPUT_IDS
-    # So input_ids are 13, 30, 47, 64. And 64 would cause a break in the loop due to exceeding max_input_len.
-    # Then with 64, we get the `input_ids[initial_prompt_length:]`, which makes our final
-    # response_ids to be 64 - 13 = 51 tokens. So in this case, we are not truncated by expected_max_response_tokens.
-    expected_final_response_tokens = 51
-
     generator = SkyRLGymGenerator(
         generator_cfg=mock_generator_cfg,
         skyrl_gym_cfg=mock_env_cfg,
         inference_engine_client=mock_llm,
-        tokenizer=mock_tokenizer,
+        tokenizer=tokenizer,
         model_name="test_model",
     )
-    generator.base_conversation_token_ids = []  # to make sure observation_ids are encoded correctly
+    generator.base_conversation_token_ids = []
 
     prompt = [{"role": "user", "content": "Initial prompt"}]
     extras = {}
 
+    # Compute limits dynamically
+    init_len = len(tokenizer.apply_chat_template(prompt, add_generation_prompt=True, tokenize=True))
+    per_turn_ids = [1, 1]
+
+    def mock_generate_len(input_batch):
+        num_prompts = len(input_batch["prompts"]) if "prompts" in input_batch else len(input_batch["prompt_token_ids"])
+        return {
+            "responses": ["abc"] * num_prompts,
+            "stop_reasons": ["stop"] * num_prompts,
+            "response_logprobs": [[0.1] * len(per_turn_ids)] * num_prompts,
+            "response_ids": [per_turn_ids.copy()] * num_prompts,
+        }
+
+    mock_llm.generate = AsyncMock(side_effect=mock_generate_len)
+
+    obs_ids = tokenizer.apply_chat_template(
+        [*generator.base_conversation, {"role": "user", "content": "next"}],
+        add_generation_prompt=True,
+        tokenize=True,
+    )[len(generator.base_conversation_token_ids):]
+    max_input_len = init_len + 2 * (len(obs_ids) + len(per_turn_ids)) + 1
+
     response_ids, _, stop_reason, loss_mask, _, _ = await generator.agent_loop(
-        prompt, "test_env", extras, max_tokens=max_tokens_from_llm, max_input_length=max_input_len
+        prompt, "test_env", extras, max_tokens=20, max_input_length=max_input_len
     )
 
-    # Verify truncation occurred
-    assert len(response_ids) <= expected_max_response_tokens
-    assert (
-        len(response_ids) == expected_final_response_tokens
-    ), f"Expected {expected_final_response_tokens} response tokens, got {len(response_ids)}"
-    assert (
-        len(loss_mask) == expected_final_response_tokens
-    ), f"Expected {expected_final_response_tokens} loss mask entries, got {len(loss_mask)}"
-
-    # Verify stop reason is "length" due to truncation
-    assert stop_reason == "length", f"Expected stop_reason='length', got '{stop_reason}'"
+    assert len(loss_mask) == len(response_ids)
+    assert stop_reason == "length"
 
 
 @pytest.mark.asyncio
 @patch("skyrl_gym.make")
 async def test_postprocessed_action_used(
-    mock_make, mock_tokenizer, mock_llm, mock_env, mock_env_cfg, mock_generator_cfg
+    mock_make, tokenizer, mock_llm, mock_env, mock_env_cfg, mock_generator_cfg, model_name
 ):
     """
     Tests that if the environment returns a `postprocessed_action`, it is used
@@ -643,32 +595,15 @@ def mock_step(_):
     mock_llm.generate.return_value = {
         "responses": [llm_raw_response],
         "stop_reasons": ["stop"],
+        "response_ids": [tokenizer.encode(llm_raw_response, add_special_tokens=False)],
     }
 
-    def mock_apply_chat_template(messages, **kwargs):
-        if kwargs.get("tokenize", True):
-            return [1] * 5  # Initial prompt tokens
-        else:
-            return "".join([msg.get("content", "") for msg in messages])
-
-    def mock_encode(text, **kwargs):
-        # The key test: postprocessed response should be encoded, not raw LLM output
-        if postprocessed_response in str(text):
-            return [42] * 10  # Distinctive tokens for postprocessed response
-        elif "new input" in str(text):
-            return [5] * 2  # Observation tokens
-        else:
-            return [1] * 3  # Default tokens
-
-    mock_tokenizer.apply_chat_template.side_effect = mock_apply_chat_template
-    mock_tokenizer.encode.side_effect = mock_encode
-
     generator = SkyRLGymGenerator(
         generator_cfg=mock_generator_cfg,
         skyrl_gym_cfg=mock_env_cfg,
         inference_engine_client=mock_llm,
-        tokenizer=mock_tokenizer,
-        model_name="test_model",
+        tokenizer=tokenizer,
+        model_name=model_name,
     )
     generator.base_conversation_token_ids = []  # to make sure observation_ids are encoded correctly
 
@@ -676,14 +611,13 @@ def mock_encode(text, **kwargs):
     env_extras = {}
 
     response_ids, reward, stop_reason, loss_mask, prompt_ids, _ = await generator.agent_loop(
-        prompt, "test_env", env_extras, max_tokens=20, max_input_length=50
+        prompt, "test_env", env_extras, max_tokens=1000, max_input_length=2000
     )
 
-    # Check that the postprocessed response tokens (42) are present in response_ids
-    # This verifies that postprocessed_action was used instead of raw LLM output
-    assert any(token == 42 for token in response_ids), f"Expected postprocessed response tokens (42) in {response_ids}"
-    # Make sure raw LLM tokens (99) are NOT present
-    assert not any(token == 99 for token in response_ids), f"Raw LLM output tokens (99) should not be in {response_ids}"
+    # Verify the postprocessed response was used instead of the raw LLM output
+    decoded_response = tokenizer.decode(response_ids)
+    assert postprocessed_response in decoded_response
+    assert llm_raw_response not in decoded_response
 
     assert reward == 1.0
     assert stop_reason == "stop"
@@ -693,7 +627,7 @@ def mock_encode(text, **kwargs):
 @pytest.mark.asyncio
 @patch("skyrl_gym.make")
 async def test_apply_overlong_filtering_non_batched(
-    mock_make, mock_tokenizer, mock_llm, mock_env, mock_generator_cfg, mock_env_cfg
+    mock_make, tokenizer, mock_llm, mock_env, mock_generator_cfg, mock_env_cfg, model_name
 ):
     """
     Test that apply_overlong_filtering correctly zeroes out loss masks for truncated trajectories
@@ -713,21 +647,12 @@ async def test_apply_overlong_filtering_non_batched(
     # Mock out the environment and inference engine generation.
     mock_env.step.side_effect = lambda x: BaseTextEnvStepOutput(observations=[], reward=1.0, done=True, metadata={})
 
-    def mock_apply_chat_template(messages, **kwargs):
-        if kwargs.get("tokenize", True):
-            return [1, 2, 3, 4, 5]  # 5 tokens for prompt
-        else:
-            return "".join([msg.get("content", "") for msg in messages])
-
-    mock_tokenizer.apply_chat_template.side_effect = mock_apply_chat_template
-    mock_tokenizer.eos_token_id = 4  # Set EOS token ID
-
     generator = SkyRLGymGenerator(
         generator_cfg=mock_generator_cfg,
         skyrl_gym_cfg=mock_env_cfg,
         inference_engine_client=mock_llm,
-        tokenizer=mock_tokenizer,
-        model_name="test_model",
+        tokenizer=tokenizer,
+        model_name=model_name,
    )
     generator.base_conversation_token_ids = []  # to make sure observation_ids are encoded correctly
 
@@ -736,7 +661,7 @@ def mock_apply_chat_template(messages, **kwargs):
         return_value={
             "responses": ["truncated response"],
             "stop_reasons": ["length"],
-            "response_ids": [[10, 11, 12, 13, 14, 15, 16, 17, 18, 19]],  # 10 tokens, will be truncated
+            "response_ids": [[10, 11, 12, 13, 14, 15, 16, 17, 18, 19]],
         }
     )
 
@@ -758,7 +683,6 @@ def mock_apply_chat_template(messages, **kwargs):
         0,
         0,
     ], "Loss mask should be all zeros for response not ending with eos token"
-    # Note: The long response gets truncated by max_response_tokens, so it doesn't end with eos token
 
     # Second test: response that ends with eos token (should not be filtered)
     # Reset the environment init to ensure clean state
@@ -767,7 +691,7 @@ def mock_apply_chat_template(messages, **kwargs):
         return_value={
             "responses": ["truncated response"],
             "stop_reasons": ["length"],
-            "response_ids": [[20, 21, 4]],  # 3 tokens, ends with eos token 4
+            "response_ids": [[20, 21, tokenizer.eos_token_id]],
         }
     )
@@ -793,11 +717,12 @@ def mock_apply_chat_template(messages, **kwargs):
 @patch("skyrl_gym.make")
 async def test_apply_overlong_filtering_batched(
     mock_make,
-    mock_tokenizer,
+    tokenizer,
     mock_llm,
     mock_env,
     mock_generator_cfg,
     mock_env_cfg,
+    model_name,
 ):
     """
     Test that apply_overlong_filtering correctly zeroes out loss masks for truncated trajectories
@@ -828,18 +753,16 @@ def mock_apply_chat_template(messages, **kwargs):
             return "".join([msg.get("content", "") for msg in messages])
 
     def mock_encode_or_tokenize(text):
-        return [10, 11, 12, 13]  # 4 tokens, doesn't end with eos_token_id=4
+        return [10, 11, 12, 13]  # 4 tokens
 
-    mock_tokenizer.apply_chat_template.side_effect = mock_apply_chat_template
-    mock_tokenizer.side_effect = lambda text: {"input_ids": mock_encode_or_tokenize(text)}
-    mock_tokenizer.eos_token_id = 4  # Set EOS token ID
+    # Note: we no longer bind to a mock tokenizer here; we use the real tokenizer for ids elsewhere.
 
     generator = SkyRLGymGenerator(
         generator_cfg=mock_generator_cfg,
         skyrl_gym_cfg=mock_env_cfg,
         inference_engine_client=mock_llm,
-        tokenizer=mock_tokenizer,
-        model_name="test_model",
+        tokenizer=tokenizer,
+        model_name=model_name,
     )
     generator.base_conversation_token_ids = []  # to make sure observation_ids are encoded correctly
diff --git a/skyrl-train/tests/cpu/generators/test_skyrl_gym_generator_chat_templating.py b/skyrl-train/tests/cpu/generators/test_skyrl_gym_generator_chat_templating.py
index bfc1262f1f..6bc56f0718 100644
--- a/skyrl-train/tests/cpu/generators/test_skyrl_gym_generator_chat_templating.py
+++ b/skyrl-train/tests/cpu/generators/test_skyrl_gym_generator_chat_templating.py
@@ -182,4 +182,74 @@ def mock_generate(input_batch):
     if "Qwen" in model_name:
         expected_loss_masks = expected_loss_masks[:-1]  # remove the extra 0 for \n
     assert len(expected_loss_masks) == len(generator_output["loss_masks"][0])
-    assert generator_output["loss_masks"][0] == expected_loss_masks
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name", ["Qwen/Qwen2.5-0.5B-Instruct", "unsloth/Llama-3.2-1B-Instruct", "Qwen/Qwen3-0.6B"]
+)
+async def test_skyrl_gym_generator_single_turn_chat_templating(model_name):
+    _register_test_env_if_needed()
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    mock_llm = MagicMock()
+
+    def mock_generate(input_batch):
+        num_prompts = len(input_batch["prompts"]) if "prompts" in input_batch else len(input_batch["prompt_token_ids"])
+        output_text = "b" + (tokenizer.eos_token or "")
+        output_ids = tokenizer.encode(output_text, add_special_tokens=False)
+        return {
+            "responses": ["b"] * num_prompts,
+            "stop_reasons": ["stop"] * num_prompts,
+            "response_logprobs": None,
+            "response_ids": [output_ids.copy()] * num_prompts,
+        }
+
+    mock_llm.generate = AsyncMock(side_effect=mock_generate)
+
+    generator_cfg = DictConfig(
+        {
+            "sampling_params": {"max_generate_length": 200, "logprobs": None},
+            "max_input_length": 200,
+            "batched": False,
+            "max_turns": 3,
+            "zero_reward_on_non_stop": False,
+            "apply_overlong_filtering": False,
+            "use_conversation_multi_turn": False,
+        }
+    )
+    env_cfg = DictConfig({"max_env_workers": 0, "env_class": "cpu_test_env"})
+
+    generator = SkyRLGymGenerator(
+        generator_cfg=generator_cfg,
+        skyrl_gym_cfg=env_cfg,
+        inference_engine_client=mock_llm,
+        tokenizer=tokenizer,
+        model_name=model_name,
+    )
+
+    input_batch: GeneratorInput = {
+        "prompts": [[{"role": "user", "content": "a"}]],
+        "env_extras": [{"answer": "4"}],
+        "env_classes": [env_cfg.env_class],
+    }
+    generator_output: GeneratorOutput = await generator.generate(input_batch)
tokenizer.encode("b", add_special_tokens=False) + tokens_1 = tokenizer.encode("1", add_special_tokens=False) + tokens_2 = tokenizer.encode("2", add_special_tokens=False) + expected_len = len(tokens_b) + len(tokens_1) + len(tokens_b) + len(tokens_2) + len(tokens_b) + 1 + resp_ids = generator_output["response_ids"][0] + loss_mask = generator_output["loss_masks"][0] + + assert len(resp_ids) == expected_len + assert resp_ids[-1] == tokenizer.eos_token_id + + expected_loss_mask = ( + [1] * len(tokens_b) + + [0] * len(tokens_1) + + [1] * len(tokens_b) + + [0] * len(tokens_2) + + [1] * len(tokens_b) + + [1] + ) + assert loss_mask == expected_loss_mask