13 | 13 | # limitations under the License.  | 
14 | 14 |  | 
15 | 15 |  | 
 | 16 | +from typing import Any, Callable  | 
 | 17 | + | 
16 | 18 | import pytest  | 
17 | 19 | import torch  | 
18 | 20 | from PIL import Image  | 
@@ -329,177 +331,124 @@ def test_batch_pad_message_log_custom_pad_value(  | 
329 | 331 |     )  | 
330 | 332 |  | 
331 | 333 |  | 
332 |  | -@pytest.mark.hf_gated  | 
333 |  | -def test_get_formatted_message_log_llama(  | 
334 |  | -    raw_chat_message_log: LLMMessageLogType,  | 
335 |  | -) -> None:  | 
336 |  | -    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")  | 
337 |  | - | 
338 |  | -    ## get expected result  | 
339 |  | -    formatted_system_message = tokenizer.apply_chat_template(  | 
340 |  | -        [raw_chat_message_log[0]],  | 
341 |  | -        tokenize=False,  | 
342 |  | -        add_generation_prompt=False,  | 
343 |  | -        add_special_tokens=False,  | 
344 |  | -    )  | 
345 |  | -    formatted_user_message = tokenizer.apply_chat_template(  | 
346 |  | -        [raw_chat_message_log[1]],  | 
347 |  | -        tokenize=False,  | 
348 |  | -        add_generation_prompt=False,  | 
349 |  | -        add_special_tokens=False,  | 
350 |  | -    )  | 
351 |  | -    formatted_assistant_message = tokenizer.apply_chat_template(  | 
352 |  | -        [raw_chat_message_log[2]],  | 
353 |  | -        tokenize=False,  | 
354 |  | -        add_generation_prompt=False,  | 
355 |  | -        add_special_tokens=False,  | 
356 |  | -    )  | 
357 |  | - | 
358 |  | -    ## text should be equivalent to if we apply chat template  | 
359 |  | -    ## to each turn separately and manually remove the bot string  | 
360 |  | -    ## from the intermediate turns  | 
361 |  | -    bot_str = "<|begin_of_text|>"  | 
362 |  | -    expected_text = [  | 
363 |  | -        formatted_system_message,  | 
364 |  | -        formatted_user_message[len(bot_str) :],  | 
365 |  | -        formatted_assistant_message[len(bot_str) :],  | 
366 |  | -    ]  | 
367 |  | - | 
368 |  | -    task_data_spec = TaskDataSpec(  | 
369 |  | -        task_name="test",  | 
370 |  | -    )  | 
371 |  | -    result = get_formatted_message_log(raw_chat_message_log, tokenizer, task_data_spec)  | 
372 |  | -    actual_text = [m["content"] for m in result]  | 
373 |  | - | 
374 |  | -    assert actual_text == expected_text  | 
375 |  | - | 
376 |  | - | 
377 |  | -@pytest.mark.hf_gated  | 
378 |  | -def test_get_formatted_message_log_add_generation_prompt_llama(  | 
 | 334 | +@pytest.mark.parametrize(  | 
 | 335 | +    "model_id, chat_log_transform",  | 
 | 336 | +    [  | 
 | 337 | +        pytest.param(  | 
 | 338 | +            "meta-llama/Meta-Llama-3-8B-Instruct",  | 
 | 339 | +            lambda raw: raw,  | 
 | 340 | +            marks=pytest.mark.hf_gated,  | 
 | 341 | +            id="llama",  | 
 | 342 | +        ),  | 
 | 343 | +        pytest.param(  | 
 | 344 | +            "google/gemma-3-27b-it",  | 
 | 345 | +            # Some Gemma chat templates (or versions) raise on system turns.  | 
 | 346 | +            # For portability across environments, test on user+assistant only.  | 
 | 347 | +            # If your tokenizer supports system turns, you can change this to `lambda raw: raw`.  | 
 | 348 | +            lambda raw: [raw[1], raw[2]],  | 
 | 349 | +            marks=pytest.mark.hf_gated,  | 
 | 350 | +            id="gemma",  | 
 | 351 | +        ),  | 
 | 352 | +        pytest.param(  | 
 | 353 | +            "Qwen/Qwen2.5-Coder-32B-Instruct",  | 
 | 354 | +            lambda raw: raw,  | 
 | 355 | +            id="qwen",  | 
 | 356 | +        ),  | 
 | 357 | +    ],  | 
 | 358 | +)  | 
 | 359 | +@pytest.mark.parametrize("add_generation_prompt", [False, True])  | 
 | 360 | +def test_get_formatted_message_log_models(  | 
379 | 361 |     raw_chat_message_log: LLMMessageLogType,  | 
 | 362 | +    model_id: str,  | 
 | 363 | +    chat_log_transform: Callable[[Any], Any],  | 
 | 364 | +    add_generation_prompt: bool,  | 
380 | 365 | ) -> None:  | 
381 |  | -    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")  | 
382 |  | - | 
383 |  | -    ## get expected result  | 
384 |  | -    formatted_system_message = tokenizer.apply_chat_template(  | 
385 |  | -        [raw_chat_message_log[0]],  | 
386 |  | -        tokenize=False,  | 
387 |  | -        add_generation_prompt=False,  | 
388 |  | -        add_special_tokens=False,  | 
389 |  | -    )  | 
390 |  | -    formatted_user_message = tokenizer.apply_chat_template(  | 
391 |  | -        [raw_chat_message_log[1]],  | 
392 |  | -        tokenize=False,  | 
393 |  | -        add_generation_prompt=True,  | 
394 |  | -        add_special_tokens=False,  | 
395 |  | -    )  | 
396 |  | -    formatted_assistant_message = (  | 
397 |  | -        raw_chat_message_log[2]["content"] + tokenizer.eos_token  | 
398 |  | -    )  | 
399 |  | - | 
400 |  | -    ## text should be equivalent to if we apply chat template  | 
401 |  | -    ## to each turn separately and manually remove the bot string  | 
402 |  | -    ## from the intermediate turns  | 
403 |  | -    bot_str = "<|begin_of_text|>"  | 
404 |  | -    expected_text = [  | 
405 |  | -        formatted_system_message,  | 
406 |  | -        formatted_user_message[len(bot_str) :],  | 
407 |  | -        formatted_assistant_message,  | 
408 |  | -    ]  | 
409 |  | - | 
410 |  | -    task_data_spec = TaskDataSpec(  | 
411 |  | -        task_name="test",  | 
412 |  | -    )  | 
 | 366 | +    """Validate that get_formatted_message_log produces text consistent with the  | 
 | 367 | +    tokenizer's chat template across models.  | 
 | 368 | + | 
 | 369 | +    This test is parametrized over model/tokenizer and whether to include a  | 
 | 370 | +    generation prompt. For models like Gemma that error on system turns, the  | 
 | 371 | +    input chat log is transformed to exclude the system message.  | 
 | 372 | + | 
 | 373 | +    Expectations:  | 
 | 374 | +    - Require an EOS token for well-defined end-of-turn comparison.  | 
 | 375 | +    - When add_generation_prompt is False, the concatenated contents must match  | 
 | 376 | +      the tokenizer's apply_chat_template output; if the tokenizer's template  | 
 | 377 | +      omits the final EOS, EOS is appended to the expected text before  | 
 | 378 | +      comparison.  | 
 | 379 | +    - When add_generation_prompt is True and the last turn is an assistant  | 
 | 380 | +      message, accept either:  | 
 | 381 | +        (1) prefix built with add_generation_prompt=True followed by the raw  | 
 | 382 | +            assistant content plus EOS; or  | 
 | 383 | +        (2) the tokenizer's full non-generation template output plus EOS.  | 
 | 384 | +      This avoids hard-coding model-specific headers or delimiters while still  | 
 | 385 | +      verifying semantic equivalence.  | 
 | 386 | +    - The only normalization performed is trimming a trailing newline after EOS.  | 
 | 387 | +    """  | 
 | 388 | +    tokenizer = AutoTokenizer.from_pretrained(model_id)  | 
 | 389 | +    chat_log = chat_log_transform(raw_chat_message_log)  | 
 | 390 | +    # Ensure tokenizer defines an EOS token; otherwise the test logic is ill-defined  | 
 | 391 | +    assert tokenizer.eos_token, "Tokenizer must define eos_token for this test"  | 
 | 392 | +    eos = tokenizer.eos_token  | 
 | 393 | +    task_data_spec = TaskDataSpec(task_name="test")  | 
413 | 394 |     result = get_formatted_message_log(  | 
414 |  | -        raw_chat_message_log,  | 
 | 395 | +        chat_log,  | 
415 | 396 |         tokenizer,  | 
416 | 397 |         task_data_spec,  | 
417 |  | -        add_generation_prompt=True,  | 
418 |  | -    )  | 
419 |  | -    actual_text = [m["content"] for m in result]  | 
420 |  | - | 
421 |  | -    assert actual_text == expected_text  | 
422 |  | - | 
423 |  | - | 
424 |  | -def test_get_formatted_message_log_qwen(  | 
425 |  | -    raw_chat_message_log: LLMMessageLogType,  | 
426 |  | -) -> None:  | 
427 |  | -    ## test using a tokenizer that does not have a bos token  | 
428 |  | -    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-32B-Instruct")  | 
429 |  | -    assert tokenizer.bos_token is None  | 
430 |  | - | 
431 |  | -    ## get expected result  | 
432 |  | -    ## result is equivalent to if we apply chat template to the full message log,  | 
433 |  | -    ## remove the trailing newline, and then partition by the delimiter  | 
434 |  | -    expected_text_string = tokenizer.apply_chat_template(  | 
435 |  | -        [raw_chat_message_log],  | 
436 |  | -        tokenize=False,  | 
437 |  | -        add_generation_prompt=False,  | 
438 |  | -        add_special_tokens=False,  | 
439 |  | -    )[0].rstrip("\n")  ## remove trailing newline  | 
440 |  | - | 
441 |  | -    delimiter = "<|im_end|>\n"  | 
442 |  | -    split_text = expected_text_string.split(delimiter)  | 
443 |  | -    expected_text = []  | 
444 |  | -    for i in range(len(split_text)):  | 
445 |  | -        if i == len(raw_chat_message_log) - 1:  | 
446 |  | -            expected_text.append(split_text[i])  | 
 | 398 | +        add_generation_prompt=add_generation_prompt,  | 
 | 399 | +    )  | 
 | 400 | +    actual_concat = "".join(m["content"] for m in result)  | 
 | 401 | + | 
 | 402 | +    def normalize(s: str) -> str:  | 
 | 403 | +        # Normalize EOS+newline quirk to EOS only  | 
 | 404 | +        if s.endswith(eos + "\n"):  | 
 | 405 | +            return s[:-1]  | 
 | 406 | +        return s  | 
 | 407 | + | 
 | 408 | +    if not add_generation_prompt:  | 
 | 409 | +        expected_concat = tokenizer.apply_chat_template(  | 
 | 410 | +            [chat_log],  | 
 | 411 | +            tokenize=False,  | 
 | 412 | +            add_generation_prompt=False,  | 
 | 413 | +            add_special_tokens=False,  | 
 | 414 | +        )[0]  | 
 | 415 | +        # Accept EOS presence even if the tokenizer's template omits it  | 
 | 416 | +        if actual_concat.endswith(eos) and not expected_concat.endswith(eos):  | 
 | 417 | +            expected_concat = expected_concat + eos  | 
 | 418 | +        assert normalize(actual_concat) == normalize(expected_concat)  | 
 | 419 | +    else:  | 
 | 420 | +        if len(chat_log) > 0 and chat_log[-1].get("role") == "assistant":  | 
 | 421 | +            prefix_log = chat_log[:-1]  | 
 | 422 | +            # Some tokenizers include a role header when add_generation_prompt=True.  | 
 | 423 | +            # Accept either behavior without hard-coding model-specific strings.  | 
 | 424 | +            prefix_gen = tokenizer.apply_chat_template(  | 
 | 425 | +                [prefix_log],  | 
 | 426 | +                tokenize=False,  | 
 | 427 | +                add_generation_prompt=True,  | 
 | 428 | +                add_special_tokens=False,  | 
 | 429 | +            )[0]  | 
 | 430 | +            assistant_suffix = chat_log[-1]["content"] + eos  | 
 | 431 | +            expected_concat_a = prefix_gen + assistant_suffix  | 
 | 432 | +            # Alternative: take the full non-generation template output and just append EOS  | 
 | 433 | +            full_no_gen = tokenizer.apply_chat_template(  | 
 | 434 | +                [chat_log],  | 
 | 435 | +                tokenize=False,  | 
 | 436 | +                add_generation_prompt=False,  | 
 | 437 | +                add_special_tokens=False,  | 
 | 438 | +            )[0]  | 
 | 439 | +            expected_concat_b = full_no_gen + eos  | 
 | 440 | +            actual_norm = normalize(actual_concat)  | 
 | 441 | +            assert actual_norm == normalize(  | 
 | 442 | +                expected_concat_a  | 
 | 443 | +            ) or actual_norm == normalize(expected_concat_b)  | 
447 | 444 |         else:  | 
448 |  | -            expected_text.append(split_text[i] + delimiter)  | 
449 |  | - | 
450 |  | -    task_data_spec = TaskDataSpec(  | 
451 |  | -        task_name="test",  | 
452 |  | -    )  | 
453 |  | -    result = get_formatted_message_log(raw_chat_message_log, tokenizer, task_data_spec)  | 
454 |  | -    actual_text = [m["content"] for m in result]  | 
455 |  | - | 
456 |  | -    assert actual_text == expected_text  | 
457 |  | - | 
458 |  | - | 
459 |  | -def test_get_formatted_message_log_add_generation_prompt_qwen(  | 
460 |  | -    raw_chat_message_log: LLMMessageLogType,  | 
461 |  | -) -> None:  | 
462 |  | -    ## test using a tokenizer that does not have a bos token  | 
463 |  | -    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-32B-Instruct")  | 
464 |  | -    assert tokenizer.bos_token is None  | 
465 |  | - | 
466 |  | -    ## get expected result  | 
467 |  | -    ## result is equivalent to if we apply chat template to the full message log,  | 
468 |  | -    ## remove the trailing newline, and then partition by the delimiter  | 
469 |  | -    ## Separately handle the last message because of the generation prompt  | 
470 |  | -    expected_text_string = tokenizer.apply_chat_template(  | 
471 |  | -        [raw_chat_message_log[:2]],  | 
472 |  | -        tokenize=False,  | 
473 |  | -        add_generation_prompt=True,  | 
474 |  | -        add_special_tokens=False,  | 
475 |  | -    )[0]  | 
476 |  | - | 
477 |  | -    delimiter = "<|im_end|>\n"  | 
478 |  | -    split_text = expected_text_string.split(delimiter, 1)  | 
479 |  | -    expected_text = []  | 
480 |  | -    for i in range(len(split_text)):  | 
481 |  | -        if i == len(split_text) - 1:  | 
482 |  | -            expected_text.append(split_text[i])  | 
483 |  | -        else:  | 
484 |  | -            expected_text.append(split_text[i] + delimiter)  | 
485 |  | - | 
486 |  | -    formatted_assistant_message = (  | 
487 |  | -        raw_chat_message_log[2]["content"] + tokenizer.eos_token  | 
488 |  | -    )  | 
489 |  | -    expected_text.append(formatted_assistant_message)  | 
490 |  | - | 
491 |  | -    task_data_spec = TaskDataSpec(  | 
492 |  | -        task_name="test",  | 
493 |  | -    )  | 
494 |  | -    result = get_formatted_message_log(  | 
495 |  | -        raw_chat_message_log,  | 
496 |  | -        tokenizer,  | 
497 |  | -        task_data_spec,  | 
498 |  | -        add_generation_prompt=True,  | 
499 |  | -    )  | 
500 |  | -    actual_text = [m["content"] for m in result]  | 
501 |  | - | 
502 |  | -    assert actual_text == expected_text  | 
 | 445 | +            expected_concat = tokenizer.apply_chat_template(  | 
 | 446 | +                [chat_log],  | 
 | 447 | +                tokenize=False,  | 
 | 448 | +                add_generation_prompt=True,  | 
 | 449 | +                add_special_tokens=False,  | 
 | 450 | +            )[0]  | 
 | 451 | +            assert normalize(actual_concat) == normalize(expected_concat)  | 
503 | 452 |  | 
504 | 453 |  | 
505 | 454 | @pytest.mark.hf_gated  | 