Fix UnicodeDecodeError permanently #118

Merged
merged 9 commits on Apr 29, 2023
6 changes: 3 additions & 3 deletions examples/low_level_api/low_level_api_chat_cpp.py
@@ -201,7 +201,7 @@ def __init__(self, params: GptParams) -> None:
# tokenize a prompt
def _tokenize(self, prompt, bos=True):
_arr = (llama_cpp.llama_token * (len(prompt) + 1))()
- _n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8"), _arr, len(_arr), bos)
+ _n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8", errors="ignore"), _arr, len(_arr), bos)
return _arr[:_n]

def set_color(self, c):
@@ -342,7 +342,7 @@ def exit(self):
# return past text
def past(self):
for id in self.last_n_tokens[-self.n_past:]:
- yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8")
+ yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8", errors="ignore")

# write input
def input(self, prompt: str):
@@ -356,7 +356,7 @@ def input(self, prompt: str):
def output(self):
self.remaining_tokens = self.params.n_predict
for id in self.generate():
- yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8")
+ yield llama_cpp.llama_token_to_str(self.ctx, id).decode("utf-8", errors="ignore")

# read user input
def read_input(self):
2 changes: 1 addition & 1 deletion examples/low_level_api/low_level_api_llama_cpp.py
@@ -70,7 +70,7 @@
if not input_noecho:
for id in embd:
print(
- llama_cpp.llama_token_to_str(ctx, id).decode("utf-8"),
+ llama_cpp.llama_token_to_str(ctx, id).decode("utf-8", errors="ignore"),
end="",
flush=True,
)
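Note on the change in the two example files above: decoding each token with errors="ignore" keeps the examples from crashing when a token carries only part of a multi-byte character. A minimal, self-contained illustration of the behaviour being relied on (plain Python, no llama.cpp calls):

# A 4-byte emoji split at a token boundary: strict decoding raises, "ignore" does not.
emoji = "😀".encode("utf-8")         # b'\xf0\x9f\x98\x80'
partial = emoji[:2]                  # the bytes a single token might end with

# partial.decode("utf-8")            # would raise UnicodeDecodeError: unexpected end of data
print(partial.decode("utf-8", errors="ignore"))   # prints "" – the invalid bytes are dropped
print(emoji.decode("utf-8", errors="ignore"))     # prints "😀" once all bytes are present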
25 changes: 20 additions & 5 deletions llama_cpp/llama.py
@@ -446,6 +446,7 @@ def _create_completion(
self.load_state(self.cache[prompt_tokens])

finish_reason = "length"
+ multibyte_fix = 0
for token in self.generate(
prompt_tokens,
top_k=top_k,
@@ -467,6 +468,20 @@
completion_tokens.append(token)

all_text = self.detokenize(completion_tokens)

+ # Contains multi-byte UTF8
+ for k,char in enumerate(all_text[-3:]):
+     k = 3 - k
+     for num,pattern in [(2, 192), (3, 224), (4, 240)]:
+         # Bitwise AND check
+         if (num > k and pattern & char == pattern):
+             multibyte_fix = num - k
+
+ # Stop incomplete bytes from passing
+ if (multibyte_fix > 0):
+     multibyte_fix -= 1
+     continue
+
any_stop = [s for s in stop_sequences if s in all_text]
if len(any_stop) > 0:
first_stop = any_stop[0]
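The loop added in the hunk above is the heart of the fix: it scans the last three bytes of the text decoded so far for a UTF-8 lead byte (the masks 192/0xC0, 224/0xE0 and 240/0xF0 mark 2-, 3- and 4-byte characters) and, when one sits closer to the end than its character is long, holds tokens back until the character is complete. A standalone sketch of that check, using a hypothetical helper name and not the PR's exact code:

def incomplete_tail_bytes(data: bytes) -> int:
    """Return how many continuation bytes are still missing at the end of data."""
    needed = 0
    tail = data[-3:]
    for k, byte in enumerate(tail):
        k = len(tail) - k                      # distance of this byte from the end
        for width, lead in [(2, 0xC0), (3, 0xE0), (4, 0xF0)]:
            # A lead byte of a width-byte character that sits fewer than
            # width bytes from the end means the character is cut off.
            if width > k and (byte & lead) == lead:
                needed = width - k
    return needed

assert incomplete_tail_bytes("😀".encode("utf-8")) == 0      # complete character
assert incomplete_tail_bytes("😀".encode("utf-8")[:2]) == 2  # two bytes still to come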
@@ -495,7 +510,7 @@
"model": self.model_path,
"choices": [
{
"text": text[start:].decode("utf-8"),
"text": text[start:].decode("utf-8", errors="ignore"),
"index": 0,
"logprobs": None,
"finish_reason": None,
@@ -516,7 +531,7 @@
"model": self.model_path,
"choices": [
{
"text": text[returned_characters:].decode("utf-8"),
"text": text[returned_characters:].decode("utf-8", errors="ignore"),
"index": 0,
"logprobs": None,
"finish_reason": finish_reason,
@@ -525,7 +540,7 @@
}
return

- text_str = text.decode("utf-8")
+ text_str = text.decode("utf-8", errors="ignore")

if echo:
text_str = prompt + text_str
@@ -543,7 +558,7 @@

all_tokens = prompt_tokens + completion_tokens
all_token_strs = [
- self.detokenize([token]).decode("utf-8") for token in all_tokens
+ self.detokenize([token]).decode("utf-8", errors="ignore") for token in all_tokens
]
all_logprobs = [
[Llama.logit_to_logprob(logit) for logit in row]
@@ -562,7 +577,7 @@
)
token_logprobs.append(sorted_logprobs[int(token)][0])
top_logprob = {
- self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
+ self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8", errors="ignore"): logprob
for logprob, i in sorted_logprobs[:logprobs]
}
top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
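Taken together, the two parts of the change in llama.py cover different failure modes: errors="ignore" prevents the hard crash, while the hold-back loop stops a streamed chunk boundary from silently swallowing a character. A toy illustration of what per-chunk decoding would otherwise do (hypothetical byte chunks, no llama.cpp calls):

data = "naïve 😀".encode("utf-8")
chunks = [data[:7], data[7:9], data[9:]]                      # "😀" is split across two chunks

print([c.decode("utf-8", errors="ignore") for c in chunks])   # ['naïve ', '', ''] – emoji lost
print(b"".join(chunks).decode("utf-8", errors="ignore"))      # 'naïve 😀' – intact when kept whole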
36 changes: 35 additions & 1 deletion tests/test_llama.py
@@ -93,4 +93,38 @@ def test_llama_pickle():

text = b"Hello World"

- assert llama.detokenize(llama.tokenize(text)) == text
+ assert llama.detokenize(llama.tokenize(text)) == text
+
+ def test_utf8(monkeypatch):
+     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
+
+     ## Set up mock function
+     def mock_eval(*args, **kwargs):
+         return 0
+
+     monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
+
+     output_text = "😀"
+     output_tokens = llama.tokenize(output_text.encode("utf-8"))
+     token_eos = llama.token_eos()
+     n = 0
+
+     def mock_sample(*args, **kwargs):
+         nonlocal n
+         if n < len(output_tokens):
+             n += 1
+             return output_tokens[n - 1]
+         else:
+             return token_eos
+
+     monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_top_p_top_k", mock_sample)
+
+     ## Test basic completion with utf8 multibyte
+     n = 0  # reset
+     completion = llama.create_completion("", max_tokens=4)
+     assert completion["choices"][0]["text"] == output_text
+
+     ## Test basic completion with incomplete utf8 multibyte
+     n = 0  # reset
+     completion = llama.create_completion("", max_tokens=1)
+     assert completion["choices"][0]["text"] == ""