Skip to content

Commit

Permalink
Add basic abstractions
Browse files Browse the repository at this point in the history
  • Loading branch information
dtch1997 committed Nov 15, 2023
1 parent 82feb1b commit b919229
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 0 deletions.
1 change: 1 addition & 0 deletions repepo/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .types import BaseDataset, BasePipeline
63 changes: 63 additions & 0 deletions repepo/core/format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import abc
import random

from typing import List
from repepo.core.types import Completion, Example

class AbstractFormatter(abc.ABC):
""" Describes how to format examples as completions """

@abc.abstractmethod
def apply(self, example: Example, **kwargs) -> str:
raise NotImplementedError()

def apply_list(self, examples: List[Example]) -> List[Completion]:
completions = []
for example in examples:
completion = self.apply(example)
completions.append(completion)
return completions

class InputOutputFormatter(AbstractFormatter):
""" Format as a simple input-output pair. """

PROMPT_TEMPLATE = (
"Input: {instruction} {input} \n"
"Output: "
)

def apply(self, example: Example, **kwargs):
del kwargs
return Completion(
prompt = self.PROMPT_TEMPLATE.format(
instruction=example.instruction,
input = example.input
),
response = example.output
)

class InstructionFormatter(AbstractFormatter):
""" Instruction formatter used for fine-tuning Alpaca. """

PROMPT_INPUT: str = (
"Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
)
PROMPT_NO_INPUT: str = (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Response:"
)

def apply(self, example: Example, **kwargs):
del kwargs
if 'input' in example and bool(example['input']):
prompt = self.PROMPT_INPUT.format_map(example)
else:
prompt = self.PROMPT_NO_INPUT.format_map(example)
response = example['output']
return Completion(
prompt = prompt,
response = response
)
70 changes: 70 additions & 0 deletions repepo/core/prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import random
import abc

from typing import List
from repepo.core.types import Completion

def completion_to_str(completion: Completion) -> str:
return completion['prompt'] + completion['response']

class AbstractPrompter(abc.ABC):
""" Interface for modifying completions """

@abc.abstractmethod
def apply(self, completion: Completion, **kwargs) -> Completion:
raise NotImplementedError()

def apply_list(self, completions: List[Completion]) -> List[Completion]:
completions_out = []
for completion in completions:
completion_out = self.apply(completion)
completions_out.append(completion_out)
return completions_out

class IdentityPrompter(AbstractPrompter):
""" Return the prompt as-is """
def apply(self, completion, **kwargs):
del kwargs
return completion

class FewShotPrompter(AbstractPrompter):
""" Compose examples few-shot """

def __init__(self, n_few_shot_examples: int = 2):
self.n_few_shot_examples = n_few_shot_examples

def apply(self, completion, few_shot_examples=[]):
prompt = '\n'.join(few_shot_examples) + '\n' + completion['prompt']
response = completion['response']
return Completion(
prompt=prompt,
response=response
)

def apply_list(self, completions: List[Completion]) -> List[Completion]:

output_completions = []

for i, completion in enumerate(completions):

# Sample different completions for context
few_shot_examples = []
selected_idxes = [i,]
for _ in range(self.n_few_shot_examples):
done = False
while not done:
idx = random.randint(0, len(completions) - 1) # range inclusive
done = idx not in selected_idxes
few_shot_examples.append(
completion_to_str(completions[idx])
)
selected_idxes.append(idx)

# Concatenate completions
output_completion = self.apply(
completion,
few_shot_examples=few_shot_examples
)
output_completions.append(output_completion)

return output_completions
41 changes: 41 additions & 0 deletions repepo/core/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import abc
from collections import namedtuple
from typing import List, Dict, Any, NewType

# Placeholder type definitions
Model = NewType('Model', Any)
Tokenizer = NewType('Tokenizer', Any)
Prompter = NewType('Prompter', Any)
Formatter = NewType('Formatter', Any)

# Base types
Example = namedtuple('Example', ('instruction', 'input', 'output'))
Completion = namedtuple('Completion', ('prompt', 'response'))

class BaseDataset(abc.ABC):

@property
def instruction(self) -> str:
return self._instruction

@property
def examples(self) -> List[Example]:
return self._examples

class BasePipeline(abc.ABC):

@property
def model(self) -> Model:
return self._model

@property
def tokenizer(self) -> Tokenizer:
return self._tokenizer

@property
def prompter(self) -> Prompter:
return self._prompter

@property
def formatter(self) -> Formatter:
return self._formatter

0 comments on commit b919229

Please sign in to comment.