From e37d9358e6f4432a1697b46fb42386ec8cc567c5 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Sun, 13 Aug 2023 01:16:18 +0900
Subject: [PATCH] Fix(message): Improve error message for bad format (#365)

---
 src/axolotl/prompt_strategies/llama2_chat.py | 4 ++--
 src/axolotl/prompters.py                     | 7 ++++++-
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/src/axolotl/prompt_strategies/llama2_chat.py b/src/axolotl/prompt_strategies/llama2_chat.py
index a3fe599620..ba6845affa 100644
--- a/src/axolotl/prompt_strategies/llama2_chat.py
+++ b/src/axolotl/prompt_strategies/llama2_chat.py
@@ -29,7 +29,7 @@
 from typing import Generator, List, Sequence
 
 from axolotl.prompt_tokenizers import PromptTokenizingStrategy
-from axolotl.prompters import IGNORE_TOKEN_ID
+from axolotl.prompters import IGNORE_TOKEN_ID, SHAREGPT_ASSERTION_FAILED_ROLE
 
 
 @dataclass
@@ -190,7 +190,7 @@ def build_prompt(self, source) -> Generator[Llama2ChatConversation, None, None]:
         conv.messages = []  # pylint: disable=R0801
         for j, sentence in enumerate(source):
             role = roles[sentence["from"]]
-            assert role == conv.roles[j % 2]
+            assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
             if sentence["value"]:
                 conv.append_message(role, sentence["value"])
         yield conv
diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py
index a304bd1370..2da4ff1124 100644
--- a/src/axolotl/prompters.py
+++ b/src/axolotl/prompters.py
@@ -260,6 +260,11 @@ def append_message(self, role, message):
         self.messages.append([role, message])
 
 
+SHAREGPT_ASSERTION_FAILED_ROLE = (
+    "Role did not alternate between turns (gpt and human). Please check your data."
+)
+
+
 class ShareGPTPrompter:  # pylint: disable=too-few-public-methods
     """
     A prompter that generates prompts for the ShareGPT
@@ -316,7 +321,7 @@ def build_prompt(self, source) -> Generator[str, None, None]:
             conv.messages = []
             for j, sentence in enumerate(source):
                 role = roles[sentence["from"]]
-                assert role == conv.roles[j % 2]
+                assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
                 conv.append_message(role, sentence["value"])
 
             for part in conv.get_prompt():
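
Note: the following is a minimal standalone sketch of the role-alternation check this patch annotates, to show the error message a user now sees. It does not import axolotl; only SHAREGPT_ASSERTION_FAILED_ROLE and the assert-with-message pattern come from the diff above. The check_alternation helper, the ROLES/EXPECTED_ORDER names, and the sample conversations are hypothetical illustration only.

# Standalone sketch: reproduces the patched alternation check on ShareGPT-style turns.
SHAREGPT_ASSERTION_FAILED_ROLE = (
    "Role did not alternate between turns (gpt and human). Please check your data."
)

# Map ShareGPT "from" fields to two conversation roles, human turn first.
ROLES = {"human": "USER", "gpt": "ASSISTANT"}
EXPECTED_ORDER = ["USER", "ASSISTANT"]


def check_alternation(source):
    """Assert that turns alternate human/gpt, mirroring the patched assert."""
    for j, sentence in enumerate(source):
        role = ROLES[sentence["from"]]
        # Before this patch the assert carried no message, so a malformed
        # dataset failed with a bare AssertionError; now the cause is stated.
        assert role == EXPECTED_ORDER[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE


# Well-formed conversation: passes silently.
check_alternation(
    [
        {"from": "human", "value": "Hi"},
        {"from": "gpt", "value": "Hello! How can I help?"},
    ]
)

# Malformed conversation (two human turns in a row): raises AssertionError
# carrying the new, descriptive message.
try:
    check_alternation(
        [
            {"from": "human", "value": "Hi"},
            {"from": "human", "value": "Are you there?"},
        ]
    )
except AssertionError as err:
    print(err)  # Role did not alternate between turns (gpt and human). ...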