Skip to content

Commit 3f9c953

Browse files
authored May 8, 2023
Merge pull request #15 from NanoCode012/feat/completion
Feat: Add Completion dataset type
2 parents bd3c5a5 + 174b74d commit 3f9c953

File tree

3 files changed

+45
-2
lines changed

3 files changed

+45
-2
lines changed
 

‎src/axolotl/prompt_tokenizers.py

+19
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,25 @@ def parse_instruction_fields(self, prompt) -> (str, str, str):
125125
)
126126

127127

128+
class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """Tokenizing strategy for plain-text "completion" records.

    A record is reduced to the single value stored under its ``"text"`` key;
    that value is passed through the prompter and then tokenized.
    """

    def parse_instruction_fields(self, prompt) -> str:
        """Return the value stored under the ``"text"`` key of ``prompt``.

        Raises whatever ``prompt["text"]`` raises when the key is absent
        (``KeyError`` for a plain dict).
        """
        return prompt["text"]

    def tokenize_prompt(self, prompt):
        """Build the full prompt for one record and return its tokenization.

        The result is whatever ``self._tokenize`` returns; that helper is not
        defined in this class (presumably inherited from the base strategy --
        confirm there for the exact output format).
        """
        text = self.parse_instruction_fields(prompt)
        return self._tokenize(self._build_full_prompt(text))

    def _build_full_prompt(self, instruction):
        """Delegate prompt construction to ``self.prompter.build_prompt``."""
        # NOTE(review): ``self.prompter`` is assigned outside this class --
        # presumably by the base-class constructor; verify against it.
        return self.prompter.build_prompt(instruction)
145+
146+
128147
class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
129148
def parse_instruction_fields(self, prompt) -> (str, str, str, str, str):
130149
raise NotImplementedError

‎src/axolotl/prompters.py

+11
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,17 @@ class JeopardyPrompter(AlpacaPrompter):
3535
prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers tbe clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
3636

3737

38+
class CompletionPrompter(AlpacaPrompter):
    """Prompter that applies no template to the text it is given."""

    def build_prompt(self, instruction: str) -> str:
        """Return ``instruction`` unchanged as the finished prompt."""
        return instruction

    def get_response(self, output: str) -> str:
        """Return ``output`` with leading and trailing whitespace removed."""
        return output.strip()
47+
48+
3849
class GPTeacherPrompter(AlpacaPrompter):
3950
...
4051

‎src/axolotl/utils/data.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,17 @@
1111
GPTeacherPromptTokenizingStrategy,
1212
OpenAssistantPromptTokenizingStrategy,
1313
AlpacaReflectionPTStrategy,
14-
ShareGPTPromptTokenizingStrategy, JeopardyPromptTokenizingStrategy,
14+
ShareGPTPromptTokenizingStrategy,
15+
JeopardyPromptTokenizingStrategy,
16+
CompletionPromptTokenizingStrategy,
1517
)
1618
from axolotl.prompters import (
1719
AlpacaPrompter,
1820
GPTeacherPrompter,
1921
ReflectAlpacaPrompter,
20-
ShareGPTPrompter, JeopardyPrompter,
22+
ShareGPTPrompter,
23+
JeopardyPrompter,
24+
CompletionPrompter,
2125
)
2226

2327

@@ -118,6 +122,15 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
118122
)
119123
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
120124
datasets.append(ds_wrapper)
125+
elif d.type == "completion":
126+
ds_strategy = CompletionPromptTokenizingStrategy(
127+
CompletionPrompter(),
128+
tokenizer,
129+
cfg.train_on_inputs,
130+
cfg.sequence_len,
131+
)
132+
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
133+
datasets.append(ds_wrapper)
121134
else:
122135
logging.error(f"unhandled prompt tokenization strategy: {d.type}")
123136
logging.info("tokenizing, merging, and shuffling master dataset")

0 commit comments

Comments
 (0)