-
Notifications
You must be signed in to change notification settings - Fork 1
/
run.py
38 lines (27 loc) · 1.1 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import argparse
from os import path
from neural_dialogue_model.model_args import Args
from neural_dialogue_model.models import NeuralDialogueModel
def create_parser():
    """Build the command-line parser for the interactive dialogue runner.

    All three options live in a "Dialogues" argument group, take one value
    (shown as ``FP`` in help) and are converted with ``os.path.abspath``.
    None of them is required; an omitted option parses to ``None`` (argparse
    only runs ``type`` on string defaults, and the default here is ``None``).

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    cli = argparse.ArgumentParser(description='')
    dialogue_opts = cli.add_argument_group("Dialogues")

    # (flag, help text) pairs, registered in this order so --help output is stable.
    option_specs = (
        ('--model', "Path to model parameters"),
        ('--spm', "Path to sentencepiece model"),
        ('--vocab', "Path to vocab"),
    )
    for flag, help_text in option_specs:
        dialogue_opts.add_argument(flag, type=path.abspath, metavar="FP", help=help_text)

    return cli
def main():
    """Run an interactive chat session against a NeuralDialogueModel.

    Parses the command line with ``create_parser``, builds ``Args`` from the
    parsed paths and constructs the model. It then prompts on stdin in a loop.
    Each turn:

    1. The utterance is appended to the conversation history.
    2. The model is called with the whole history.
    3. Every returned candidate is printed as a bullet list.
    4. The first candidate is appended to the history as the model's reply.

    The session ends when the user types ``q``, when stdin reaches
    end-of-input (Ctrl-D / closed pipe), or on Ctrl-C at the prompt.
    """
    parser = create_parser()
    parser_args = parser.parse_args()
    args = Args(model_path=parser_args.model, spm_path=parser_args.spm, vocab_path=parser_args.vocab)
    model = NeuralDialogueModel(args)

    # Conversation history: user utterances and the model's chosen replies.
    # The model is given the entire list every turn.
    contexts = []
    while True:
        try:
            utterance = input("input: ")
        except (EOFError, KeyboardInterrupt):
            # End-of-input or Ctrl-C at the prompt: finish the session like
            # "q" does instead of crashing with a traceback.
            print()
            break
        if utterance == "q":
            break
        contexts.append(utterance)
        responses = model(contexts)
        if not responses:
            # Without this guard, responses[0] below would raise IndexError.
            # NOTE(review): the utterance stays in the history without a reply;
            # confirm that is acceptable for the model.
            print("output: (no response)")
            continue
        print("output:\n- " + "\n- ".join(responses))
        # Only the top candidate is kept as the model's turn in the history.
        contexts.append(responses[0])
# Start the interactive session only when run as a script, not on import.
if __name__ == "__main__":
    main()