-
Notifications
You must be signed in to change notification settings - Fork 15
/
sample.py
31 lines (21 loc) · 1012 Bytes
/
sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
'''
MIT License
Copyright (c) 2017 Mat Leonard
'''
import argparse
from model import CharRNN, load_model, sample
def _non_negative_int(text):
    """Parse *text* as an int, rejecting values below zero (argparse ``type=`` hook)."""
    value = int(text)
    if value < 0:
        raise argparse.ArgumentTypeError(
            f"expected a non-negative integer, got {value}")
    return value


def _positive_int(text):
    """Parse *text* as an int, rejecting values below one (argparse ``type=`` hook)."""
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {value}")
    return value


def build_parser():
    """Return the command-line parser for sampling text from a saved checkpoint."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # Required positional argument: argparse never uses a default for it,
    # so none is supplied.
    parser.add_argument('checkpoint', type=str,
                        help='initialize network from checkpoint')
    # store_true already implies a default of False.
    parser.add_argument('--gpu', action='store_true',
                        help='run the network on the GPU')
    # Validated at parse time so bad values produce a usage error instead of
    # being forwarded to sample().
    parser.add_argument('--num_samples', type=_non_negative_int, default=200,
                        help='number of samples for generating text')
    parser.add_argument('--prime', type=str, default='From afar',
                        help='prime the network with characters for sampling')
    parser.add_argument('--top_k', type=_positive_int, default=10,
                        help='sample from top K character probabilities')
    return parser


def main(argv=None):
    """Load the checkpoint named on the command line and print sampled text.

    argv: list of argument strings to parse; ``None`` means ``sys.argv[1:]``
    (argparse's default behavior).
    """
    args = build_parser().parse_args(argv)
    net = load_model(args.checkpoint)
    print(sample(net, args.num_samples, cuda=args.gpu,
                 top_k=args.top_k, prime=args.prime))


# Only run when executed as a script, so importing this module has no side
# effects (no argv parsing, no model loading).
if __name__ == '__main__':
    main()