Skip to content

Commit

Permalink
Prevent returning partial stop string (#1392)
Browse files Browse the repository at this point in the history
  • Loading branch information
mingfang authored May 22, 2023
1 parent 621bc89 commit 75d8ab2
Showing 1 changed file with 27 additions and 10 deletions.
37 changes: 27 additions & 10 deletions fastchat/serve/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ def prepare_logits_processor(
return processor_list


def partial_stop(output, stop_str):
    """Return True when the end of *output* could be the start of *stop_str*.

    Used while streaming to detect that the generated text currently ends
    with a prefix of the stop string, so the caller can hold the chunk back
    until it is clear whether the full stop string is being produced.

    Args:
        output: The text decoded so far.
        stop_str: The stop sequence to test against.

    Returns:
        True if the whole of ``output`` is a prefix of ``stop_str``, or if
        some non-empty tail of ``output`` that is shorter than both strings
        is a prefix of ``stop_str``. False otherwise, including whenever
        either argument is empty.
    """
    limit = min(len(output), len(stop_str))
    if limit == 0:
        # Nothing to compare; an empty string must not count as a match.
        return False

    # The whole output may itself be a prefix of the stop string. The
    # original loop performed this check implicitly on its first iteration,
    # because ``output[-0:]`` evaluates to the entire string.
    if stop_str.startswith(output):
        return True

    # Remaining candidates: tails of ``output`` of length 1 .. limit - 1.
    return any(
        stop_str.startswith(output[-size:]) for size in range(1, limit)
    )


@torch.inference_mode()
def generate_stream(
model, tokenizer, params, device, context_len=2048, stream_interval=2
Expand Down Expand Up @@ -160,31 +167,41 @@ def generate_stream(
skip_special_tokens=True,
spaces_between_special_tokens=False,
)

partially_stopped = False
if stop_str:
if isinstance(stop_str, str):
pos = output.rfind(stop_str, rfind_start)
if pos != -1:
output = output[:pos]
stopped = True
else:
partially_stopped = partial_stop(output, stop_str)
elif isinstance(stop_str, Iterable):
for each_stop in stop_str:
pos = output.rfind(each_stop, rfind_start)
if pos != -1:
output = output[:pos]
stopped = True
break
else:
partially_stopped = partial_stop(output, each_stop)
if partially_stopped:
break
else:
raise ValueError("Invalid stop field type.")

yield {
"text": output,
"usage": {
"prompt_tokens": input_echo_len,
"completion_tokens": i,
"total_tokens": input_echo_len + i,
},
"finish_reason": None,
}

# prevent yielding partial stop sequence
if not partially_stopped:
yield {
"text": output,
"usage": {
"prompt_tokens": input_echo_len,
"completion_tokens": i,
"total_tokens": input_echo_len + i,
},
"finish_reason": None,
}

if stopped:
break
Expand Down

0 comments on commit 75d8ab2

Please sign in to comment.