Skip to content

Commit

Permalink
add cut_line_residual option
Browse files Browse the repository at this point in the history
strip extra `\n` in FIM completion
use more robust FIM scheme (suffix + prefix + middle)
  • Loading branch information
JegernOUTT committed Aug 12, 2023
1 parent 54a3e9b commit ba32d49
Showing 1 changed file with 29 additions and 40 deletions.
69 changes: 29 additions & 40 deletions refact_scratchpads/scratchpad_hf.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch as th
import time
import json
import os

from refact_scratchpads.scratchpad_utils import trim_context_infill

Expand Down Expand Up @@ -140,13 +141,23 @@ def completion(self, final: bool):


class ScratchpadHuggingface(ScratchpadHuggingfaceBase):
def __init__(self, sources: Dict[str, str], cursor_file: str, cursor0: int, cursor1: int, **kwargs):

def __init__(
self,
sources: Dict[str, str],
cursor_file: str,
cursor0: int,
cursor1: int,
cut_line_residual: bool = False
**kwargs
):
super().__init__(**kwargs)

assert cursor0 == cursor1

self._cursor_file = cursor_file
self._cursor = cursor0
self._cut_line_residual = cut_line_residual
self._code = sources[cursor_file]

self._prefix: Optional[str] = None
Expand All @@ -158,51 +169,28 @@ def __init__(self, sources: Dict[str, str], cursor_file: str, cursor0: int, curs
self._fim_suffix = self._encode_one_token("<fim_suffix>")
self._fim_middle = self._encode_one_token("<fim_middle>")

def after_token_selection(self, m, chosen_token: th.Tensor, top_tokens: List[int] = [], **unused) -> Dict[str, Any]:
t = chosen_token.item()
self.debuglog("%05d %s" % (t, self._tokenizer.decode([t])))

if chosen_token in [self._tokenizer.eos_token_id] or self._tokenizer.eos_token_id in top_tokens:
self.finish_reason = "eot"

elif chosen_token in self._special_tokens:
self.finish_reason = "special-token"

if not self.finish_reason:
self._completion.append(t)
if chosen_token in self._stop_tokens:
self.finish_reason = "stoptoken"

t_str = self._tokenizer.decode([t])
if self._stop_lf and t_str.startswith("\n"):
self.finish_reason = "stop-lf"
if self._stop_lf_lf and t_str.startswith("\n\n"):
self.finish_reason = "stop-lflf"

self._tokens_produced += 1
if self._tokens_produced % 5 == 0:
self.needs_upload = True

return dict()

def prompt(self, T: int):
self._prefix = self._code[:self._cursor]
self._suffix = "".join(self._code[self._cursor:].splitlines(keepends=True)[1:])
if self._cut_line_residual:
self._suffix = "".join(self._code[self._cursor:].splitlines(keepends=True)[1:])
else:
self._suffix = self._code[self._cursor:]
self._completion.clear()

prefix_cut, suffix_cut = trim_context_infill(
self._prefix, self._suffix, EncodingWrapper(self._tokenizer), T - self._max_tokens)
self.debuglog(f"ScratchpadHuggingfaceFIM prompt prefix %i chars, suffix %i chars, T=%i max_tokens=%i" % (
len(prefix_cut), len(suffix_cut), T, self._max_tokens))
a = self.encode_without_special_tokens(prefix_cut)
b = self.encode_without_special_tokens(suffix_cut)
assert self._fim_prefix not in a
assert self._fim_prefix not in b
self._prefix, self._suffix, EncodingWrapper(self._tokenizer), T - self._max_tokens
)
self.debuglog(
f"ScratchpadHuggingfaceFIM prompt prefix {len(prefix_cut)} chars, "
f"suffix {len(suffix_cut)} chars, T={T} max_tokens={self._max_tokens}"
)
prefix_cut_tokens = self.encode_without_special_tokens(prefix_cut)
suffix_cut_tokens = self.encode_without_special_tokens(suffix_cut)
prompt: List[int] = [
self._fim_prefix,
*a,
self._fim_suffix,
*b,
*prefix_cut_tokens,
self._fim_prefix,
*suffix_cut_tokens,
self._fim_middle,
]
# self.debuglog("-"*40)
Expand All @@ -213,8 +201,9 @@ def prompt(self, T: int):
def completion(self, final: bool):
assert self._prefix is not None
assert self._suffix is not None
completion = self._tokenizer.decode(self._completion).rstrip(os.linesep)
return {
self._cursor_file: self._prefix + self._tokenizer.decode(self._completion) + self._suffix,
self._cursor_file: self._prefix + completion + self._suffix,
}


Expand Down

0 comments on commit ba32d49

Please sign in to comment.