From bfb99095ad229022efbbf7197d33137b4a433704 Mon Sep 17 00:00:00 2001 From: cOng Date: Thu, 1 Feb 2024 12:13:13 +0800 Subject: [PATCH] fix: tokenization mismatch --- fastchat/train/train_with_template.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/fastchat/train/train_with_template.py b/fastchat/train/train_with_template.py index a947ffe87..4511a2108 100644 --- a/fastchat/train/train_with_template.py +++ b/fastchat/train/train_with_template.py @@ -132,9 +132,9 @@ def get_prompt_separator(conv): elif conv.sep_style == SeparatorStyle.CHATML: if conv.sep2 is None: - user_turn_separator = conv.sep + user_turn_separator = conv.sep + "\n" else: - user_turn_separator = conv.sep2 + user_turn_separator = conv.sep2 + "\n" assistant_turn_separator = conv.roles[1] + "\n" @@ -155,7 +155,9 @@ def mask_targets(conversations, targets, tokenizer, conv): user_turn_separator, assistant_turn_separator = get_prompt_separator(conv) turns = conversation.split(user_turn_separator) for i, turn in enumerate(turns): - if turn == "": + if ( + i < len(turns) - 1 and turn == "" + ): # Last turn is the user_turn_separator break if i != 0: