-
Notifications
You must be signed in to change notification settings - Fork 37
/
standard_parser.py
93 lines (77 loc) · 5.35 KB
/
standard_parser.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# Copyright (c) 2018 Uber Technologies, Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import os
import argparse
# Architecture names accepted by --arch when the caller supplies no list of its
# own; make_standard_parser uses the first entry as the --arch default.
DEFAULT_ARCH_CHOICES = ['mnist']


def make_standard_parser(description='No description provided', arch_choices=DEFAULT_ARCH_CHOICES,
                         skip_train=False, skip_val=False):
    """Make a standard argparse parser, probably good for many experiments.

    Arguments:
        description: text shown at the top of the ``--help`` output.
        arch_choices: sequence of strings that may be given to ``--arch`` when
            selecting the architecture type. For example, ('mnist', 'cifar')
            would allow selection of different networks for each dataset.
            The default architecture is the first entry. Lists and tuples
            are both accepted.
        skip_train: if True, skip adding the positional ``train_h5`` arg.
        skip_val: if True, skip adding the positional ``val_h5`` arg.

    Returns:
        An ``argparse.ArgumentParser`` whose help output also shows each
        option's default value.

    Raises:
        IndexError: if ``arch_choices`` is empty (there is no first entry to
            use as the default architecture).
    """
    # ArgumentDefaultsHelpFormatter appends "(default: ...)" to option help
    # text. The class itself is a valid formatter_class; no wrapper is needed.
    parser = argparse.ArgumentParser(description=description,
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Optimization
    parser.add_argument('--opt', type=str, default='sgd', choices=('sgd', 'rmsprop', 'adam'),
                        help='Which optimizer to use')
    parser.add_argument('--lr', '-L', type=float, default=.001, help='learning rate')
    parser.add_argument('--mom', '-M', type=float, default=.9,
                        help='momentum (only has effect for sgd/rmsprop)')
    parser.add_argument('--beta1', type=float, default=.9, help='beta1 for adam opt')
    parser.add_argument('--beta2', type=float, default=.99, help='beta2 for adam opt')
    parser.add_argument('--adameps', type=float, default=1e-8, help='epsilon for adam opt')
    parser.add_argument('--epochs', '-E', type=int, default=5, help='number of epochs.')

    # Model
    # arch_choices is wrapped in a 1-tuple: a bare tuple on the right of `%` is
    # unpacked as the format arguments, which raised TypeError for inputs
    # such as ('mnist', 'cifar').
    parser.add_argument('--arch', type=str, default=arch_choices[0], choices=arch_choices,
                        help='Which architecture to use (choices: %s).' % (arch_choices,))
    parser.add_argument('--conv', '-C', action='store_true', help='Use a conv model.')
    parser.add_argument('--springprop', '-S', action='store_true', help='Use a springprop model')
    parser.add_argument('--springt', '-t', type=float, default=0.5,
                        help='T value to use for springs')
    parser.add_argument('--learncoords', '--lc', action='store_true',
                        help='Learn coordinates (update them during training) instead of keeping them fixed.')
    parser.add_argument('--l2', type=float, default=0.0,
                        help='L2 regularization to apply to direct parameters.')
    parser.add_argument('--l2i', type=float, default=0.0,
                        help='L2 regularization to apply to indirect parameters.')

    # Experimental setup
    parser.add_argument('--seed', type=int, default=0,
                        help='random number seed for initial params and tf graph')
    parser.add_argument('--test', action='store_true',
                        help='Use test data instead of validation data (for final run).')
    # --shuffletrain / --noshuffletrain share one dest. set_defaults below makes
    # shuffling the default, so the "(default: True)" shown for
    # --noshuffletrain refers to the shared dest, hence its help note.
    parser.add_argument('--shuffletrain', '--st', dest='shuffletrain', action='store_true',
                        help='Shuffle training set each epoch.')
    parser.add_argument('--noshuffletrain', '--nst', dest='shuffletrain', action='store_false',
                        help='Do not shuffle training set each epoch. Ignore the following "default" value:')
    parser.set_defaults(shuffletrain=True)

    # Misc
    parser.add_argument('--ipy', '-I', action='store_true',
                        help='drop into embedded iPython for debugging.')
    parser.add_argument('--nocolor', '--nc', action='store_true',
                        help='Do not use color output (for scripts).')
    parser.add_argument('--skipval', action='store_true', help='Skip validation set entirely.')
    parser.add_argument('--verbose', '-V', action='store_true',
                        help='Verbose mode (print some extra stuff)')

    # Saving and loading
    parser.add_argument('--snapshot-to', type=str, default='net',
                        help='Where to snapshot to. --snapshot-to NAME produces NAME_iter.h5 and NAME.json')
    parser.add_argument('--snapshot-every', type=int, default=-1,
                        help='Snapshot every N minibatches. 0 to disable snapshots, -1 to snapshot only on last iteration.')
    parser.add_argument('--load', type=str, default=None,
                        help='Snapshot to load from: specify as H5_FILE:MISC_FILE.')
    parser.add_argument('--output', '-O', type=str, default=None,
                        help='directory to output TF results to. If nothing else: skips output.')

    # Dataset (positional arguments, added only when not skipped)
    if not skip_train:
        parser.add_argument('train_h5', type=str, help='Training set hdf5 file.')
    if not skip_val:
        parser.add_argument('val_h5', type=str, help='Validation set hdf5 file.')
    return parser