-
Notifications
You must be signed in to change notification settings - Fork 129
/
sample.py
executable file
·115 lines (86 loc) · 3.24 KB
/
sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#!/usr/bin/env python
import argparse
import cPickle
import traceback
import logging
import time
import sys
import os
import numpy
import codecs
import search
import utils
from dialog_encdec import DialogEncoderDecoder
from numpy_compat import argpartition
from state import prototype_state
# Module-level logger for this script; main() sets its level/format through
# logging.basicConfig using the loaded state's 'level' entry.
logger = logging.getLogger(__name__)
class Timer(object):
    """Stopwatch that sums wall-clock seconds over repeated start/finish pairs.

    Call ``start()`` before each ``finish()``; ``total`` holds the running sum
    of every completed interval.
    """

    def __init__(self):
        # Accumulated seconds across all finished intervals.
        self.total = 0

    def start(self):
        """Record the wall-clock time at which the current interval begins."""
        now = time.time()
        self.start_time = now

    def finish(self):
        """Add the time elapsed since the last ``start()`` to ``total``."""
        elapsed = time.time() - self.start_time
        self.total += elapsed
def parse_args(argv=None):
    """Parse the command-line options for sampling from a trained dialog model.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, matching the
            original no-argument behaviour.

    Returns:
        argparse.Namespace with the positional ``model_prefix``, ``context``,
        ``output`` and ``changes`` values plus the flags ``ignore_unk``,
        ``beam_search``, ``n_samples``, ``n_turns`` and ``verbose``.
    """
    # Pass the text as ``description``: the first positional parameter of
    # ArgumentParser is ``prog``, which would put this sentence in the usage line.
    parser = argparse.ArgumentParser(
        description="Sample (with beam-search) from the session model")

    # store_false: ignore_unk defaults to True; supplying --ignore-unk sets it
    # to False, which per the help text lets <unk> tokens be generated.
    parser.add_argument("--ignore-unk",
                        action="store_false",
                        help="Allows generation procedure to output unknown words (<unk> tokens)")
    parser.add_argument("model_prefix",
                        help="Path to the model prefix (without _model.npz or _state.pkl)")
    parser.add_argument("context",
                        help="File of input contexts")
    parser.add_argument("output",
                        help="Output file")
    parser.add_argument("--beam_search",
                        action="store_true",
                        help="Use beam search instead of random search")
    # Integer default (was the string "1", which only worked because argparse
    # runs string defaults through ``type``).
    parser.add_argument("--n-samples",
                        default=1, type=int,
                        help="Number of samples")
    parser.add_argument("--n-turns",
                        default=1, type=int,
                        help="Number of dialog turns to generate")
    parser.add_argument("--verbose",
                        action="store_true", default=False,
                        help="Be verbose")
    parser.add_argument("changes", nargs="?", default="", help="Changes to state")
    return parser.parse_args(argv)
def main():
    """Sample dialog continuations from a trained model and save them to a file.

    Loads ``<model_prefix>_state.pkl`` on top of ``prototype_state()`` and the
    weights in ``<model_prefix>_model.npz``. Reads one context per line from
    ``args.context`` and samples with ``search.BeamSampler`` when
    ``--beam_search`` is given, otherwise with ``search.RandomSampler``. Writes
    one line per context to ``args.output``, holding that context's samples
    joined by tabs.

    Raises:
        Exception: if ``<model_prefix>_model.npz`` is not an existing file.
    """
    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    # NOTE(review): unpickling can execute arbitrary code -- only load state
    # files that come from a trusted source.
    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(level=getattr(logging, state['level']),
                        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    # Fail fast, before constructing the model and sampler.
    if not os.path.isfile(model_path):
        raise Exception("Must specify a valid model path")

    model = DialogEncoderDecoder(state)

    # Build only the sampler that will actually be used.
    if args.beam_search:
        sampler = search.BeamSampler(model)
    else:
        sampler = search.RandomSampler(model)

    logger.debug("Loading previous model")
    model.load(model_path)

    # An empty context file yields a single empty context ([[]]); otherwise
    # each line, stripped of surrounding whitespace, is one context.
    with open(args.context, "r") as context_file:
        lines = context_file.readlines()
    contexts = [line.strip() for line in lines] if lines else [[]]

    print('Sampling started...')
    # sample() returns (samples, costs); only the samples are written out.
    context_samples, _ = sampler.sample(contexts,
                                        n_samples=args.n_samples,
                                        n_turns=args.n_turns,
                                        ignore_unk=args.ignore_unk,
                                        verbose=args.verbose)
    print('Sampling finished.')
    print('Saving to file...')

    # Write to output file: one tab-joined line of samples per context.
    # write() works on both Python 2 and 3, unlike ``print >> handle``.
    with open(args.output, "w") as output_handle:
        for context_sample in context_samples:
            output_handle.write('\t'.join(context_sample) + '\n')

    print('Saving to file finished.')
    print('All done!')
# Run the sampling pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()