Skip to content

Commit

Permalink
Merge pull request #199 from NanoCode012/chore/prompter-arg
Browse files Browse the repository at this point in the history
chore: Refactor inf_kwargs out
  • Loading branch information
NanoCode012 authored Jun 13, 2023
2 parents aaadacf + dc77c8e commit 068fc48
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions scripts/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def get_multi_line_input() -> Optional[str]:
return instruction


def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
def do_inference(cfg, model, tokenizer, prompter: Optional[str]):
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}

for token, symbol in default_tokens.items():
Expand Down Expand Up @@ -257,13 +257,13 @@ def train(

if cfg.inference:
logging.info("calling do_inference function")
inf_kwargs: Dict[str, Any] = {}
prompter: Optional[str] = "AlpacaPrompter"
if "prompter" in kwargs:
if kwargs["prompter"] == "None":
inf_kwargs["prompter"] = None
prompter = None
else:
inf_kwargs["prompter"] = kwargs["prompter"]
do_inference(cfg, model, tokenizer, **inf_kwargs)
prompter = kwargs["prompter"]
do_inference(cfg, model, tokenizer, prompter=prompter)
return

if "shard" in kwargs:
Expand Down

0 comments on commit 068fc48

Please sign in to comment.