diff --git a/chat/dialogues.py b/chat/dialogues.py index 634c4a1..31a7ab3 100644 --- a/chat/dialogues.py +++ b/chat/dialogues.py @@ -112,17 +112,17 @@ def _save_pretrained(self, save_directory: Union[str, Path]) -> None: @classmethod def _from_pretrained( - cls: Type[T], - *, - model_id: str, - revision: Optional[str], - cache_dir: Optional[Union[str, Path]], - force_download: bool, - proxies: Optional[Dict], - resume_download: bool, - local_files_only: bool, - token: Optional[Union[str, bool]], - **model_kwargs, + cls: Type[T], + *, + model_id: str, + revision: Optional[str], + cache_dir: Optional[Union[str, Path]], + force_download: bool, + proxies: Optional[Dict], + resume_download: bool, + local_files_only: bool, + token: Optional[Union[str, bool]], + **model_kwargs, ) -> T: """Loads the dialogue template from a local directory or the Huggingface Hub. @@ -232,10 +232,11 @@ def prepare_dialogue(example, dialogue_template, is_train=True): def mask_user_labels(tokenizer, dialogue_template, labels): """Masks the user turns of a dialogue from the loss""" user_token_id = tokenizer.convert_tokens_to_ids(dialogue_template.user_token) + system_token_id = tokenizer.convert_tokens_to_ids(dialogue_template.system_token) assistant_token_id = tokenizer.convert_tokens_to_ids(dialogue_template.assistant_token) for idx, label_id in enumerate(labels): - if label_id == user_token_id: + if label_id in [user_token_id, system_token_id]: current_idx = idx - while labels[current_idx] != assistant_token_id and current_idx < len(labels): + while current_idx < len(labels) and labels[current_idx] != assistant_token_id: labels[current_idx] = IGNORE_INDEX current_idx += 1 diff --git a/chat/train.py b/chat/train.py index df383aa..23ed459 100644 --- a/chat/train.py +++ b/chat/train.py @@ -201,7 +201,8 @@ def group_texts(examples): for k, t in concatenated_examples.items() } labels = result["input_ids"].copy() - mask_user_labels(tokenizer, dialogue_template, labels) + for label in labels: + mask_user_labels(tokenizer, dialogue_template, label) result["labels"] = labels return result