Skip to content

Commit

Permalink
Add Code Llama scratchpad and 7B 8-bit model
Browse files Browse the repository at this point in the history
  • Loading branch information
mitya52 committed Aug 28, 2023
1 parent 0c3a0fd commit ecc7ede
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 0 deletions.
11 changes: 11 additions & 0 deletions known_models_db/refact_known_models/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,17 @@
"required_memory_mb": 14000,
"filter_caps": ["llama2"],
},
"codellama/7b": {
"backend": "transformers",
"model_path": "TheBloke/CodeLlama-7B-fp16",
"diff_scratchpad_class": "refact_scratchpads:ScratchpadCodeLlama",
"chat_scratchpad_class": None,
"model_class_kwargs": {
"load_in_8bit": True,
},
"required_memory_mb": 14000,
"filter_caps": ["completion"],
},
"wizardlm/30b/4bit": {
"backend": "transformers",
"model_path": "TheBloke/WizardLM-30B-fp16",
Expand Down
1 change: 1 addition & 0 deletions refact_scratchpads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from refact_scratchpads.scratchpad_hf import ScratchpadHuggingfaceBase
from refact_scratchpads.scratchpad_hf import ScratchpadHuggingfaceCompletion
from refact_scratchpads.scratchpad_hf import ScratchpadHuggingface
from refact_scratchpads.scratchpad_hf import ScratchpadCodeLlama
from refact_scratchpads.scratchpad_hf import ScratchpadHuggingfaceStarChat
from refact_scratchpads.scratchpad_hf import ScratchpadHuggingfaceWizard
from refact_scratchpads.scratchpad_hf import ScratchpadHuggingfaceLlama2
Expand Down
49 changes: 49 additions & 0 deletions refact_scratchpads/scratchpad_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,55 @@ def completion(self, final: bool):
}


class ScratchpadCodeLlama(ScratchpadHuggingfaceBase):
    """Fill-in-the-middle (FIM) completion scratchpad for Code Llama.

    Builds a <PRE> prefix <SUF> suffix <MID> prompt around the cursor and,
    on completion, splices the generated middle back between prefix and suffix.
    Only a collapsed cursor (cursor0 == cursor1) is supported.
    """

    def __init__(self, sources: Dict[str, str], cursor_file: str, cursor0: int, cursor1: int, **kwargs):
        super().__init__(**kwargs)

        # FIM infilling has no notion of a selection — require a collapsed cursor.
        assert cursor0 == cursor1

        self._cursor_file = cursor_file
        self._cursor = cursor0
        self._code = sources[cursor_file]

        # Filled in by prompt(); completion() asserts they are set first.
        self._prefix: Optional[str] = None
        self._suffix: Optional[str] = None
        self._completion = []

        self._tokens_produced = 0
        # Resolve the single-token ids of the Code Llama infilling markers once.
        # NOTE(review): assumes each marker encodes to exactly one token in this
        # tokenizer — _encode_one_token presumably enforces that.
        for attr, marker in (
            ("_fim_prefix", "<PRE>"),
            ("_fim_suffix", "<SUF>"),
            ("_fim_middle", "<MID>"),
            ("_fim_eot", "<EOT>"),
        ):
            setattr(self, attr, self._encode_one_token(marker))
        self._special_tokens.update({
            self._fim_prefix, self._fim_suffix, self._fim_middle, self._fim_eot,
        })

    def prompt(self, T: int):
        """Return the FIM token prompt fitting within T minus the generation budget."""
        self._prefix = self._code[:self._cursor]
        # Drop the remainder of the cursor line; keep everything after it as suffix.
        tail = self._code[self._cursor:].splitlines(keepends=True)
        self._suffix = "".join(tail[1:])
        self._completion.clear()

        # Trim prefix/suffix so the prompt leaves room for _max_tokens of output.
        budget = T - self._max_tokens
        cut_prefix, cut_suffix = trim_context_infill(
            self._prefix, self._suffix, EncodingWrapper(self._tokenizer), budget)

        tokens: List[int] = [self._eos_token, self._fim_prefix]
        tokens.extend(self._tokenizer.encode(cut_prefix))
        tokens.append(self._fim_suffix)
        tokens.extend(self._tokenizer.encode(cut_suffix))
        tokens.append(self._fim_middle)
        return tokens

    def completion(self, final: bool):
        """Return {file: text} with the generated middle spliced at the cursor."""
        assert self._prefix is not None
        assert self._suffix is not None
        middle = self._tokenizer.decode(self._completion)
        return {self._cursor_file: self._prefix + middle + self._suffix}


class ScratchpadChatBase(ScratchpadHuggingfaceBase):

def __init__(self, messages: List[Dict[str, str]], **kwargs):
Expand Down

0 comments on commit ecc7ede

Please sign in to comment.