Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
skyrl-train/tests/cpu/generators/test_skyrl_gym_generator_chat_templating.py::test_skyrl_gym_generator_chat_templating_exact
"""

from datetime import date


# Produced by expected_str = tokenizer.apply_chat_template(expected_chat_history, tokenize=False)
# where expected_chat_history is:
Expand Down Expand Up @@ -33,10 +35,10 @@ def get_expected_chat_history(mock_response_text: str):
b<|im_end|>
"""

LLAMA3_2_EXPECTED_STR = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
LLAMA3_2_EXPECTED_STR = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 10 Oct 2025
Today Date: {date.today().strftime("%d %b %Y")}

<|eot_id|><|start_header_id|>user<|end_header_id|>

Expand Down
12 changes: 10 additions & 2 deletions skyrl-train/tests/gpu/gpu_ci/test_skyrl_gym_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,10 +327,18 @@ async def test_generator_formatting_use_conversation_multi_turn(model_name):
assert (
f"{OBSERVATION_PROMPT} 2" in masked_out_resp_str
), f'"{OBSERVATION_PROMPT} 2" observation should be loss masked out'
# TODO(Charlie): add more rigorous tests that is robust to stop_reason being length.
# Either make GeneratorOutput return stop reason for each turn, or change the way we manage
# max generation length.
num_resp_eos = sum(1 for _ in masked_in_resp_ids if _ == tokenizer.eos_token_id)
num_total_eos = sum(1 for _ in resp_ids if _ == tokenizer.eos_token_id)
common_msg = "Could be due to stop_reason is length in some of the turns."
# count number of eos tokens in masked_in_resp_ids: 1 eos per assistant response (3 turns)
assert sum(1 for _ in masked_in_resp_ids if _ == tokenizer.eos_token_id) == 3
if num_resp_eos != 3:
logger.warning(f"Got {num_resp_eos} eos tokens in masked_in_resp_ids, expected 3. {common_msg}")
# total eos in full response: 2 user eos + 3 assistant eos
assert sum(1 for _ in resp_ids if _ == tokenizer.eos_token_id) == 5
if num_total_eos != 5:
logger.warning(f"Got {num_total_eos} eos tokens in resp_ids, expected 5. {common_msg}")
else:
# On length stops, the model may not produce EOS at the end of each assistant turn.
# Only check that generation prompts are masked out.
Expand Down