Fixes #249 #255

Merged: 1 commit, Jun 8, 2023

Changes from all commits
h2oai_pipeline.py: 49 changes (36 additions, 13 deletions)
@@ -50,24 +50,47 @@ def __init__(self, *args, debug=False, chat=False, stream_output=False,
         self.max_input_tokens = max_input_tokens  # not for generate, so ok that not kwargs
 
     def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
-        if self.prompt_type not in [PromptType.plain.name, PromptType.plain.value] and \
-                hasattr(self.tokenizer, 'model_max_length'):
+        if hasattr(self.tokenizer, 'model_max_length'):
             # model_max_length only defined for generate.py, not raw use of h2oai_pipeline.py
             model_max_length = self.tokenizer.model_max_length
-            verbose = False  # FIXME: debug
+        else:
+            # unknown
+            model_max_length = None
 
+        verbose = False  # FIXME: debug
+        if model_max_length is not None:
+            num_prompt_tokens = None
             # can't wait for "hole" if not plain prompt_type, since would lose prefix like <human>:
             # For https://github.com/h2oai/h2ogpt/issues/192
-            prompt_tokens = self.tokenizer(prompt_text)['input_ids']
-            num_prompt_tokens = len(prompt_tokens)
-            if num_prompt_tokens > model_max_length:
-                # conservative by using int()
-                chars_per_token = int(len(prompt_text) / num_prompt_tokens)
-                prompt_text = prompt_text[-model_max_length * chars_per_token:]
+            for trial in range(0, 3):
+                prompt_tokens = self.tokenizer(prompt_text)['input_ids']
+                num_prompt_tokens = len(prompt_tokens)
+                if num_prompt_tokens > model_max_length:
+                    # conservative by using int()
+                    chars_per_token = int(len(prompt_text) / num_prompt_tokens)
+                    prompt_text = prompt_text[-model_max_length * chars_per_token:]
+                    if verbose:
+                        print("reducing %s tokens, assuming average of %s chars/token for %s characters" % (
+                            num_prompt_tokens, chars_per_token, len(prompt_text)), flush=True)
+                else:
+                    if verbose:
+                        print("using %s tokens with %s chars" % (num_prompt_tokens, len(prompt_text)), flush=True)
+                    break
 
+            # if input prompt is some number of tokens, despite user request, can't have max_new_tokens more
+            #
+            if self.prompt_type not in [PromptType.plain.name, PromptType.plain.value]:
+                # then give room for prompt
+                fudge = 20
+            else:
+                fudge = 0
+            assert num_prompt_tokens is not None
+            max_new_tokens = max(0, min(generate_kwargs['max_new_tokens'],
+                                        model_max_length - (num_prompt_tokens + fudge)))
             if max_new_tokens < generate_kwargs['max_new_tokens']:
                 if verbose:
-                    print("reducing tokens, assuming average of %s chars/token: %s" % chars_per_token, flush=True)
+                    prompt_tokens2 = self.tokenizer(prompt_text)['input_ids']
+                    num_prompt_tokens2 = len(prompt_tokens2)
+                    print("reduced tokens from %d -> %d" % (num_prompt_tokens, num_prompt_tokens2), flush=True)
+                    print("Reduced max_new_tokens from %s -> %s" % (generate_kwargs['max_new_tokens'], max_new_tokens))
                 generate_kwargs['max_new_tokens'] = max_new_tokens
 
         data_point = dict(context='', instruction=prompt_text, input='')
         if self.prompter is not None:
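
The core of the hunk is the retry loop: tokenize the prompt, and if it exceeds model_max_length, estimate an average characters-per-token ratio, keep only the tail of the prompt, and re-check, up to three times. Below is a minimal standalone sketch of that strategy, assuming a generic Hugging Face tokenizer; the function name limit_prompt_length and the max(1, ...) guard are illustrative additions for this note, not code from the PR.

```python
def limit_prompt_length(prompt_text, tokenizer, model_max_length, max_trials=3, verbose=False):
    """Illustrative re-tokenize-and-truncate loop (sketch, not the PR's exact code)."""
    num_prompt_tokens = None
    for _ in range(max_trials):
        prompt_tokens = tokenizer(prompt_text)['input_ids']
        num_prompt_tokens = len(prompt_tokens)
        if num_prompt_tokens <= model_max_length:
            break
        # int() rounds down, so the estimate errs toward cutting more text than needed;
        # max(1, ...) guards against a zero estimate for unusually token-dense text
        chars_per_token = max(1, int(len(prompt_text) / num_prompt_tokens))
        # keep the tail of the prompt, where the most recent context/question usually lives
        prompt_text = prompt_text[-model_max_length * chars_per_token:]
        if verbose:
            print("retry: %d tokens, ~%d chars/token, %d chars kept"
                  % (num_prompt_tokens, chars_per_token, len(prompt_text)), flush=True)
    return prompt_text, num_prompt_tokens

# Usage (model name is only an example):
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("h2oai/h2ogpt-oig-oasst1-512-6_9b")
#   short_prompt, n_tokens = limit_prompt_length(long_prompt, tok, tok.model_max_length)
```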
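The other change budgets max_new_tokens against the context window: generation gets whatever remains of model_max_length after the prompt tokens plus a small fudge (20 tokens for non-plain prompt types) reserved for the prompt wrapper such as a "<human>:" prefix. A hedged sketch of that arithmetic with illustrative numbers follows; clamp_max_new_tokens is a made-up name for this note, not a function in the repo.

```python
def clamp_max_new_tokens(requested_new_tokens, model_max_length, num_prompt_tokens, fudge=20):
    """Sketch of the budget check in the hunk above.

    `fudge` reserves room for prompt-template tokens; the PR uses 20 for
    non-plain prompt types and 0 otherwise.
    """
    return max(0, min(requested_new_tokens, model_max_length - (num_prompt_tokens + fudge)))

# Worked example with illustrative numbers:
# model_max_length=2048, a 1900-token prompt, fudge=20, user asks for 512 new tokens
# -> 2048 - (1900 + 20) = 128, so max_new_tokens is reduced from 512 to 128.
assert clamp_max_new_tokens(512, 2048, 1900, fudge=20) == 128
# A short prompt leaves the request untouched:
assert clamp_max_new_tokens(256, 2048, 100, fudge=20) == 256
```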