From 94c226edb34a3a49cc09ba0fc42d08794fcacc0c Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 24 Jan 2025 12:00:06 -0500 Subject: [PATCH 1/2] fixes last eos token not in labels on basic use case --- src/axolotl/prompt_strategies/chat_template.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 5b12130d75..4a8993c821 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -487,7 +487,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None strategy_params = { "train_on_inputs": cfg.train_on_inputs, "sequence_len": cfg.sequence_len, - "roles_to_train": ds_cfg.get("roles_to_train", ["assistant"]), + "roles_to_train": ds_cfg.get("roles_to_train", None), "train_on_eos": ds_cfg.get("train_on_eos", "turn"), } From 6c49083d8b5a1d1835c1bd037d8334de9bd2a877 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 24 Jan 2025 12:02:34 -0500 Subject: [PATCH 2/2] improve check for base case --- src/axolotl/prompt_strategies/chat_template.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 4a8993c821..27fae31b92 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -223,7 +223,7 @@ def messages(self, messages): def tokenize_prompt(self, prompt): # Old simple legacy behavior that works reliably. if ( - not self.roles_to_train + (not self.roles_to_train or self.roles_to_train == ["assistant"]) and not self.train_on_eos and not self.prompter.message_field_training and not self.prompter.message_field_training_detail @@ -487,7 +487,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None strategy_params = { "train_on_inputs": cfg.train_on_inputs, "sequence_len": cfg.sequence_len, - "roles_to_train": ds_cfg.get("roles_to_train", None), + "roles_to_train": ds_cfg.get("roles_to_train", ["assistant"]), "train_on_eos": ds_cfg.get("train_on_eos", "turn"), }