Skip to content

Commit

Permalink
Fix sharegpt prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
NanoCode012 committed May 30, 2023
1 parent cfcc549 commit 25eeeeb
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 12 deletions.
15 changes: 8 additions & 7 deletions src/axolotl/prompt_tokenizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,15 +371,16 @@ def tokenize_prompt(self, prompt):
]
# not masked out from labels
labels = copy.deepcopy(res["input_ids"])
elif part[0] == "SYSTEM:":
part = part[1] # Ignore the system role from preamble
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
part.strip(), add_eos_token=False, strip_bos_token=False
)
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
logging.warning(f"unhandled role: {part[0]}")
else:
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
part.strip(), add_eos_token=False, strip_bos_token=False
)
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])

# pylint: disable=duplicate-code
result, current_len = parse_tokenized_to_result(
Expand Down
10 changes: 5 additions & 5 deletions src/axolotl/prompters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import dataclasses
import logging
from enum import Enum, auto
from typing import Generator, List, Optional, Union
from typing import Generator, List, Optional, Tuple, Union

IGNORE_TOKEN_ID = -100

Expand Down Expand Up @@ -235,16 +235,16 @@ class Conversation:
sep: str = "###"
sep2: Optional[str] = None

def get_prompt(self) -> Generator[str, None, None]:
def get_prompt(self) -> Generator[Tuple[str, str], None, None]:
# seps = [self.sep, self.sep2]
preamble = self.system + self.sep
yield preamble
yield ("SYSTEM:", preamble)
for _, (role, message) in enumerate(self.messages):
if message:
yield role + ":" + " " + message
yield (role + ":", " " + message)
else:
logging.warning(f"role with empty message: {role}")
yield role + ":"
yield (role + ":", "")

def copy(self):
return Conversation(
Expand Down

0 comments on commit 25eeeeb

Please sign in to comment.