-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
175 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .types import BaseDataset, BasePipeline |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import abc | ||
import random | ||
|
||
from typing import List | ||
from repepo.core.types import Completion, Example | ||
|
||
class AbstractFormatter(abc.ABC): | ||
""" Describes how to format examples as completions """ | ||
|
||
@abc.abstractmethod | ||
def apply(self, example: Example, **kwargs) -> str: | ||
raise NotImplementedError() | ||
|
||
def apply_list(self, examples: List[Example]) -> List[Completion]: | ||
completions = [] | ||
for example in examples: | ||
completion = self.apply(example) | ||
completions.append(completion) | ||
return completions | ||
|
||
class InputOutputFormatter(AbstractFormatter): | ||
""" Format as a simple input-output pair. """ | ||
|
||
PROMPT_TEMPLATE = ( | ||
"Input: {instruction} {input} \n" | ||
"Output: " | ||
) | ||
|
||
def apply(self, example: Example, **kwargs): | ||
del kwargs | ||
return Completion( | ||
prompt = self.PROMPT_TEMPLATE.format( | ||
instruction=example.instruction, | ||
input = example.input | ||
), | ||
response = example.output | ||
) | ||
|
||
class InstructionFormatter(AbstractFormatter): | ||
""" Instruction formatter used for fine-tuning Alpaca. """ | ||
|
||
PROMPT_INPUT: str = ( | ||
"Below is an instruction that describes a task, paired with an input that provides further context. " | ||
"Write a response that appropriately completes the request.\n\n" | ||
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" | ||
) | ||
PROMPT_NO_INPUT: str = ( | ||
"Below is an instruction that describes a task. " | ||
"Write a response that appropriately completes the request.\n\n" | ||
"### Instruction:\n{instruction}\n\n### Response:" | ||
) | ||
|
||
def apply(self, example: Example, **kwargs): | ||
del kwargs | ||
if 'input' in example and bool(example['input']): | ||
prompt = self.PROMPT_INPUT.format_map(example) | ||
else: | ||
prompt = self.PROMPT_NO_INPUT.format_map(example) | ||
response = example['output'] | ||
return Completion( | ||
prompt = prompt, | ||
response = response | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import random | ||
import abc | ||
|
||
from typing import List | ||
from repepo.core.types import Completion | ||
|
||
def completion_to_str(completion: Completion) -> str: | ||
return completion['prompt'] + completion['response'] | ||
|
||
class AbstractPrompter(abc.ABC): | ||
""" Interface for modifying completions """ | ||
|
||
@abc.abstractmethod | ||
def apply(self, completion: Completion, **kwargs) -> Completion: | ||
raise NotImplementedError() | ||
|
||
def apply_list(self, completions: List[Completion]) -> List[Completion]: | ||
completions_out = [] | ||
for completion in completions: | ||
completion_out = self.apply(completion) | ||
completions_out.append(completion_out) | ||
return completions_out | ||
|
||
class IdentityPrompter(AbstractPrompter): | ||
""" Return the prompt as-is """ | ||
def apply(self, completion, **kwargs): | ||
del kwargs | ||
return completion | ||
|
||
class FewShotPrompter(AbstractPrompter): | ||
""" Compose examples few-shot """ | ||
|
||
def __init__(self, n_few_shot_examples: int = 2): | ||
self.n_few_shot_examples = n_few_shot_examples | ||
|
||
def apply(self, completion, few_shot_examples=[]): | ||
prompt = '\n'.join(few_shot_examples) + '\n' + completion['prompt'] | ||
response = completion['response'] | ||
return Completion( | ||
prompt=prompt, | ||
response=response | ||
) | ||
|
||
def apply_list(self, completions: List[Completion]) -> List[Completion]: | ||
|
||
output_completions = [] | ||
|
||
for i, completion in enumerate(completions): | ||
|
||
# Sample different completions for context | ||
few_shot_examples = [] | ||
selected_idxes = [i,] | ||
for _ in range(self.n_few_shot_examples): | ||
done = False | ||
while not done: | ||
idx = random.randint(0, len(completions) - 1) # range inclusive | ||
done = idx not in selected_idxes | ||
few_shot_examples.append( | ||
completion_to_str(completions[idx]) | ||
) | ||
selected_idxes.append(idx) | ||
|
||
# Concatenate completions | ||
output_completion = self.apply( | ||
completion, | ||
few_shot_examples=few_shot_examples | ||
) | ||
output_completions.append(output_completion) | ||
|
||
return output_completions |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import abc | ||
from collections import namedtuple | ||
from typing import List, Dict, Any, NewType | ||
|
||
# Placeholder type definitions | ||
Model = NewType('Model', Any) | ||
Tokenizer = NewType('Tokenizer', Any) | ||
Prompter = NewType('Prompter', Any) | ||
Formatter = NewType('Formatter', Any) | ||
|
||
# Base types | ||
Example = namedtuple('Example', ('instruction', 'input', 'output')) | ||
Completion = namedtuple('Completion', ('prompt', 'response')) | ||
|
||
class BaseDataset(abc.ABC): | ||
|
||
@property | ||
def instruction(self) -> str: | ||
return self._instruction | ||
|
||
@property | ||
def examples(self) -> List[Example]: | ||
return self._examples | ||
|
||
class BasePipeline(abc.ABC): | ||
|
||
@property | ||
def model(self) -> Model: | ||
return self._model | ||
|
||
@property | ||
def tokenizer(self) -> Tokenizer: | ||
return self._tokenizer | ||
|
||
@property | ||
def prompter(self) -> Prompter: | ||
return self._prompter | ||
|
||
@property | ||
def formatter(self) -> Formatter: | ||
return self._formatter |