Merge pull request axolotl-ai-cloud#180 from Glavin001/feat/stream-inference

Add streaming inference & fix stopping at EOS
winglian authored Jun 10, 2023
2 parents e7c9647 + ba8bf40 commit f4c21e6
Showing 1 changed file with 12 additions and 5 deletions.

scripts/finetune.py
@@ -12,7 +12,7 @@
 import fire
 import torch
 import yaml
-from transformers import GenerationConfig
+from transformers import GenerationConfig, TextStreamer
 
 from axolotl.utils.data import load_prepare_datasets
 from axolotl.utils.dict import DictDefault
@@ -64,13 +64,17 @@ def get_multi_line_input() -> Optional[str]:
 
 
 def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
-    tokenizer.add_special_tokens({"unk_token": "<unk>"})
-    tokenizer.add_special_tokens({"bos_token": "<s>"})
-    tokenizer.add_special_tokens({"eos_token": "</s>"})
+    default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
+
+    for token, symbol in default_tokens.items():
+        # If the token isn't already specified in the config, add it
+        if not (cfg.special_tokens and token in cfg.special_tokens):
+            tokenizer.add_special_tokens({token: symbol})
 
     prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter)
 
     while True:
+        print("=" * 80)
         # support for multiline inputs
         instruction = get_multi_line_input()
         if not instruction:
@@ -79,7 +83,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
             prompter_module().build_prompt(instruction=instruction.strip("\n"))
         )
         batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
-
+        print("=" * 40)
         model.eval()
         with torch.no_grad():
             generation_config = GenerationConfig(
@@ -98,10 +102,13 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
                 output_hidden_states=False,
                 output_scores=False,
             )
+            streamer = TextStreamer(tokenizer)
             generated = model.generate(
                 inputs=batch["input_ids"].to(cfg.device),
                 generation_config=generation_config,
+                streamer=streamer,
            )
+        print("=" * 40)
         print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
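
The streaming half of the change wires transformers' TextStreamer into model.generate, so decoded tokens are printed to stdout as they are produced instead of only after generation finishes. Below is a minimal standalone sketch of the same pattern; the "gpt2" model name, the prompt, and the generation settings are placeholder assumptions for illustration, while axolotl builds its model and GenerationConfig from the YAML config.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, TextStreamer

model_name = "gpt2"  # placeholder; any causal LM works the same way
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()

batch = tokenizer("The quick brown fox", return_tensors="pt")
streamer = TextStreamer(tokenizer)  # decodes and prints tokens incrementally

generation_config = GenerationConfig(max_new_tokens=40, do_sample=False)
with torch.no_grad():
    # Passing streamer= makes generate() emit text token by token;
    # the return value still contains the full sequences as before.
    model.generate(
        inputs=batch["input_ids"],
        generation_config=generation_config,
        streamer=streamer,
    )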

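The EOS half of the fix is the guard in the token loop above: the previous code unconditionally overrode unk_token, bos_token, and eos_token with LLaMA-style defaults, so a model whose config declared a different eos_token could generate past its real end-of-sequence token. Here is a minimal sketch of the guard in isolation, assuming a hypothetical cfg_special_tokens dict standing in for axolotl's cfg.special_tokens and a placeholder gpt2 tokenizer:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
cfg_special_tokens = {"eos_token": "<|endoftext|>"}  # assumed YAML-configured value

default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
for token, symbol in default_tokens.items():
    # Fall back to the default symbol only when the config does not
    # already define this special token; a configured EOS token is
    # left alone, so generation can stop at the token the model emits.
    if not (cfg_special_tokens and token in cfg_special_tokens):
        tokenizer.add_special_tokens({token: symbol})

print(tokenizer.eos_token)  # still "<|endoftext|>", not overridden to "</s>"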
