Commit efe09d7

Merge branch 'main' into jiemingz/fp8_block
2 parents: 0fcba54 + 38e9ef1

File tree

2 files changed: 126 additions & 167 deletions


nemo_rl/data/llm_message_utils.py

Lines changed: 12 additions & 2 deletions
@@ -548,13 +548,23 @@ def _format_content_helper(
             message_chunk = tokenizer.bos_token + message_chunk
 
         if i == len(message_log_strs) - 1:
-            message_chunk = message_chunk.rstrip("\n")
+            r"""
+            This is an attempt to robustly append the eos token. The origin is that Qwen
+            chat templates always append <eos>\n, while some models like Gemma do not
+            use the <eos> at all in the chat template. Adding an <eos> when one is
+            already at the end is likely a user error, and since we know Qwen likes to
+            have <eos>\n, we'll check for that case.
+
+            This makes the logic slightly more robust to the model family's chat template,
+            so users don't need to know whether they need to set add_eos or not.
+            """
+            stripped_message_chunk = message_chunk.rstrip("\n")
             if add_eos_token:
                 if tokenizer.eos_token is None:
                     warnings.warn(
                         "add_eos_token is True but the tokenizer does not have an EOS token. Skipping EOS token addition."
                     )
-                elif not message_chunk.endswith(tokenizer.eos_token):
+                elif not stripped_message_chunk.endswith(tokenizer.eos_token):
                     message_chunk += tokenizer.eos_token
 
         # get images too (extend this for other modalities)
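The change above checks for the EOS on a newline-stripped copy of the final chunk but still appends to the original, un-stripped chunk. A minimal standalone sketch of that behavior, assuming a hypothetical helper name (append_eos_if_missing) and illustrative template strings rather than real tokenizer output:

from typing import Optional


def append_eos_if_missing(message_chunk: str, eos_token: Optional[str]) -> str:
    # Mirror the logic in the diff: test for EOS on a newline-stripped copy,
    # but append to the original chunk so any trailing "\n" from the template survives.
    if eos_token is None:
        return message_chunk  # the real code warns and skips in this case
    stripped = message_chunk.rstrip("\n")
    if not stripped.endswith(eos_token):
        message_chunk += eos_token
    return message_chunk


# Qwen-style output already ends in "<|im_end|>\n", so no second EOS is appended.
assert append_eos_if_missing("answer<|im_end|>\n", "<|im_end|>") == "answer<|im_end|>\n"
# Gemma-style output carries no EOS at all, so one is appended exactly once.
assert append_eos_if_missing("answer\n", "<eos>") == "answer\n<eos>"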

tests/unit/data/test_llm_message_utils.py

Lines changed: 114 additions & 165 deletions
@@ -13,6 +13,8 @@
 # limitations under the License.
 
 
+from typing import Any, Callable
+
 import pytest
 import torch
 from PIL import Image
@@ -329,177 +331,124 @@ def test_batch_pad_message_log_custom_pad_value(
     )
 
 
-@pytest.mark.hf_gated
-def test_get_formatted_message_log_llama(
-    raw_chat_message_log: LLMMessageLogType,
-) -> None:
-    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
-
-    ## get expected result
-    formatted_system_message = tokenizer.apply_chat_template(
-        [raw_chat_message_log[0]],
-        tokenize=False,
-        add_generation_prompt=False,
-        add_special_tokens=False,
-    )
-    formatted_user_message = tokenizer.apply_chat_template(
-        [raw_chat_message_log[1]],
-        tokenize=False,
-        add_generation_prompt=False,
-        add_special_tokens=False,
-    )
-    formatted_assistant_message = tokenizer.apply_chat_template(
-        [raw_chat_message_log[2]],
-        tokenize=False,
-        add_generation_prompt=False,
-        add_special_tokens=False,
-    )
-
-    ## text should be equivalent to if we apply chat template
-    ## to each turn separately and manually remove the bot string
-    ## from the intermediate turns
-    bot_str = "<|begin_of_text|>"
-    expected_text = [
-        formatted_system_message,
-        formatted_user_message[len(bot_str) :],
-        formatted_assistant_message[len(bot_str) :],
-    ]
-
-    task_data_spec = TaskDataSpec(
-        task_name="test",
-    )
-    result = get_formatted_message_log(raw_chat_message_log, tokenizer, task_data_spec)
-    actual_text = [m["content"] for m in result]
-
-    assert actual_text == expected_text
-
-
-@pytest.mark.hf_gated
-def test_get_formatted_message_log_add_generation_prompt_llama(
+@pytest.mark.parametrize(
+    "model_id, chat_log_transform",
+    [
+        pytest.param(
+            "meta-llama/Meta-Llama-3-8B-Instruct",
+            lambda raw: raw,
+            marks=pytest.mark.hf_gated,
+            id="llama",
+        ),
+        pytest.param(
+            "google/gemma-3-27b-it",
+            # Some Gemma chat templates (or versions) raise on system turns.
+            # For portability across environments, test on user+assistant only.
+            # If your tokenizer supports system turns, you can change this to `lambda raw: raw`.
+            lambda raw: [raw[1], raw[2]],
+            marks=pytest.mark.hf_gated,
+            id="gemma",
+        ),
+        pytest.param(
+            "Qwen/Qwen2.5-Coder-32B-Instruct",
+            lambda raw: raw,
+            id="qwen",
+        ),
+    ],
+)
+@pytest.mark.parametrize("add_generation_prompt", [False, True])
+def test_get_formatted_message_log_models(
     raw_chat_message_log: LLMMessageLogType,
+    model_id: str,
+    chat_log_transform: Callable[[Any], Any],
+    add_generation_prompt: bool,
 ) -> None:
-    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
-
-    ## get expected result
-    formatted_system_message = tokenizer.apply_chat_template(
-        [raw_chat_message_log[0]],
-        tokenize=False,
-        add_generation_prompt=False,
-        add_special_tokens=False,
-    )
-    formatted_user_message = tokenizer.apply_chat_template(
-        [raw_chat_message_log[1]],
-        tokenize=False,
-        add_generation_prompt=True,
-        add_special_tokens=False,
-    )
-    formatted_assistant_message = (
-        raw_chat_message_log[2]["content"] + tokenizer.eos_token
-    )
-
-    ## text should be equivalent to if we apply chat template
-    ## to each turn separately and manually remove the bot string
-    ## from the intermediate turns
-    bot_str = "<|begin_of_text|>"
-    expected_text = [
-        formatted_system_message,
-        formatted_user_message[len(bot_str) :],
-        formatted_assistant_message,
-    ]
-
-    task_data_spec = TaskDataSpec(
-        task_name="test",
-    )
+    """Validate that get_formatted_message_log produces text consistent with the
+    tokenizer's chat template across models.
+
+    This test is parametrized over model/tokenizer and whether to include a
+    generation prompt. For models like Gemma that error on system turns, the
+    input chat log is transformed to exclude the system message.
+
+    Expectations:
+    - Require an EOS token for well-defined end-of-turn comparison.
+    - When add_generation_prompt is False, the concatenated contents must match
+      the tokenizer's apply_chat_template output; if the tokenizer omits a final
+      EOS, accept the actual with EOS by appending EOS to the expected before
+      comparison.
+    - When add_generation_prompt is True and the last turn is an assistant
+      message, accept either:
+      (1) prefix built with add_generation_prompt=True followed by the raw
+          assistant content plus EOS; or
+      (2) the tokenizer's full non-generation template output plus EOS.
+      This avoids hard-coding model-specific headers or delimiters while still
+      verifying semantic equivalence.
+    - Only normalization performed is trimming a trailing newline after EOS.
+    """
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    chat_log = chat_log_transform(raw_chat_message_log)
+    # Ensure tokenizer defines an EOS token; otherwise the test logic is ill-defined
+    assert tokenizer.eos_token, "Tokenizer must define eos_token for this test"
+    eos = tokenizer.eos_token
+    task_data_spec = TaskDataSpec(task_name="test")
     result = get_formatted_message_log(
-        raw_chat_message_log,
+        chat_log,
         tokenizer,
         task_data_spec,
-        add_generation_prompt=True,
-    )
-    actual_text = [m["content"] for m in result]
-
-    assert actual_text == expected_text
-
-
-def test_get_formatted_message_log_qwen(
-    raw_chat_message_log: LLMMessageLogType,
-) -> None:
-    ## test using a tokenizer that does not have a bos token
-    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-32B-Instruct")
-    assert tokenizer.bos_token is None
-
-    ## get expected result
-    ## result is equivalent to if we apply chat template to the full message log,
-    ## remove the trailing newline, and then partition by the delimiter
-    expected_text_string = tokenizer.apply_chat_template(
-        [raw_chat_message_log],
-        tokenize=False,
-        add_generation_prompt=False,
-        add_special_tokens=False,
-    )[0].rstrip("\n")  ## remove trailing newline
-
-    delimiter = "<|im_end|>\n"
-    split_text = expected_text_string.split(delimiter)
-    expected_text = []
-    for i in range(len(split_text)):
-        if i == len(raw_chat_message_log) - 1:
-            expected_text.append(split_text[i])
+        add_generation_prompt=add_generation_prompt,
+    )
+    actual_concat = "".join(m["content"] for m in result)
+
+    def normalize(s: str) -> str:
+        # Normalize EOS+newline quirk to EOS only
+        if s.endswith(eos + "\n"):
+            return s[:-1]
+        return s
+
+    if not add_generation_prompt:
+        expected_concat = tokenizer.apply_chat_template(
+            [chat_log],
+            tokenize=False,
+            add_generation_prompt=False,
+            add_special_tokens=False,
+        )[0]
+        # Accept EOS presence even if the tokenizer's template omits it
+        if actual_concat.endswith(eos) and not expected_concat.endswith(eos):
+            expected_concat = expected_concat + eos
+        assert normalize(actual_concat) == normalize(expected_concat)
+    else:
+        if len(chat_log) > 0 and chat_log[-1].get("role") == "assistant":
+            prefix_log = chat_log[:-1]
+            # Some tokenizers include a role header when add_generation_prompt=True.
+            # Accept either behavior without hard-coding model-specific strings.
+            prefix_gen = tokenizer.apply_chat_template(
+                [prefix_log],
+                tokenize=False,
+                add_generation_prompt=True,
+                add_special_tokens=False,
+            )[0]
+            assistant_suffix = chat_log[-1]["content"] + eos
+            expected_concat_a = prefix_gen + assistant_suffix
+            # Alternative: take the full non-generation template output and just append EOS
+            full_no_gen = tokenizer.apply_chat_template(
+                [chat_log],
+                tokenize=False,
+                add_generation_prompt=False,
+                add_special_tokens=False,
+            )[0]
+            expected_concat_b = full_no_gen + eos
+            actual_norm = normalize(actual_concat)
+            assert actual_norm == normalize(
+                expected_concat_a
+            ) or actual_norm == normalize(expected_concat_b)
         else:
-            expected_text.append(split_text[i] + delimiter)
-
-    task_data_spec = TaskDataSpec(
-        task_name="test",
-    )
-    result = get_formatted_message_log(raw_chat_message_log, tokenizer, task_data_spec)
-    actual_text = [m["content"] for m in result]
-
-    assert actual_text == expected_text
-
-
-def test_get_formatted_message_log_add_generation_prompt_qwen(
-    raw_chat_message_log: LLMMessageLogType,
-) -> None:
-    ## test using a tokenizer that does not have a bos token
-    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-32B-Instruct")
-    assert tokenizer.bos_token is None
-
-    ## get expected result
-    ## result is equivalent to if we apply chat template to the full message log,
-    ## remove the trailing newline, and then partition by the delimiter
-    ## Separately handle the last message because of the generation prompt
-    expected_text_string = tokenizer.apply_chat_template(
-        [raw_chat_message_log[:2]],
-        tokenize=False,
-        add_generation_prompt=True,
-        add_special_tokens=False,
-    )[0]
-
-    delimiter = "<|im_end|>\n"
-    split_text = expected_text_string.split(delimiter, 1)
-    expected_text = []
-    for i in range(len(split_text)):
-        if i == len(split_text) - 1:
-            expected_text.append(split_text[i])
-        else:
-            expected_text.append(split_text[i] + delimiter)
-
-    formatted_assistant_message = (
-        raw_chat_message_log[2]["content"] + tokenizer.eos_token
-    )
-    expected_text.append(formatted_assistant_message)
-
-    task_data_spec = TaskDataSpec(
-        task_name="test",
-    )
-    result = get_formatted_message_log(
-        raw_chat_message_log,
-        tokenizer,
-        task_data_spec,
-        add_generation_prompt=True,
-    )
-    actual_text = [m["content"] for m in result]
-
-    assert actual_text == expected_text
+            expected_concat = tokenizer.apply_chat_template(
+                [chat_log],
+                tokenize=False,
+                add_generation_prompt=True,
+                add_special_tokens=False,
+            )[0]
+            assert normalize(actual_concat) == normalize(expected_concat)
 
 
 @pytest.mark.hf_gated
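The parametrized test compares the concatenated per-turn contents against the tokenizer's own apply_chat_template output, tolerating only the EOS-plus-newline quirk. A small sketch of that normalization idea, using an illustrative placeholder EOS string rather than a specific tokenizer's token:

EOS = "<|im_end|>"  # illustrative placeholder for tokenizer.eos_token


def normalize(s: str) -> str:
    # Drop only the single newline that some chat templates emit after the EOS.
    return s[:-1] if s.endswith(EOS + "\n") else s


# Qwen-style "EOS + newline" and a bare EOS normalize to the same string.
assert normalize("assistant reply" + EOS + "\n") == "assistant reply" + EOS
assert normalize("assistant reply" + EOS) == "assistant reply" + EOS
# Strings that do not end in EOS are left untouched, so genuine mismatches still fail.
assert normalize("assistant reply\n") == "assistant reply\n"

Because only the llama and gemma cases are marked hf_gated, the qwen case can be exercised without gated-model access, for example with pytest tests/unit/data/test_llm_message_utils.py -k qwen.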
