-
Notifications
You must be signed in to change notification settings - Fork 38
/
train_model.py
87 lines (77 loc) · 2.48 KB
/
train_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from __future__ import absolute_import
from __future__ import print_function
import argparse
from util import get_data, get_model, cross_entropy
from keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf
def train(dataset='mnist', batch_size=128, epochs=50):
    """
    Train one model with on-the-fly data augmentation and save it to disk.

    The model returned by ``get_model(dataset)`` is compiled with the
    categorical cross-entropy loss and the Adadelta optimizer, then fitted
    on batches produced by an ``ImageDataGenerator`` configured with
    ``rotation_range=20``, ``width_shift_range=0.2``,
    ``height_shift_range=0.2`` and ``horizontal_flip=True``. The test split
    returned by ``get_data`` is used as validation data. The fitted model is
    written to ``data/model_<dataset>.h5``.

    :param dataset: name handed to ``get_data`` / ``get_model`` and used in
        the output file name.
    :param batch_size: number of samples per augmented training batch.
    :param epochs: number of epochs to train for.
    :return: None
    """
    print('Data set: %s' % dataset)
    X_train, Y_train, X_test, Y_test = get_data(dataset)
    model = get_model(dataset)
    model.compile(
        loss='categorical_crossentropy',
        optimizer='adadelta',
        metrics=['accuracy']
    )

    # Training without augmentation would be a plain
    # model.fit(X_train, Y_train, ...) call; here batches are generated
    # with random augmentation instead.
    datagen = ImageDataGenerator(
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        horizontal_flip=True)

    # steps_per_epoch must be an integer. Plain "/" yields a float on
    # Python 3 and silently floors on Python 2 (dropping the last partial
    # batch), so use integer ceiling division: one epoch == one full pass.
    steps_per_epoch = (len(X_train) + batch_size - 1) // batch_size

    model.fit_generator(
        datagen.flow(X_train, Y_train, batch_size=batch_size),
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        verbose=1,
        validation_data=(X_test, Y_test))
    model.save('data/model_%s.h5' % dataset)
def main(args):
    """
    Validate the requested dataset and train the corresponding model(s).

    When ``args.dataset`` is ``'all'``, ``train`` is called once for each of
    ``'mnist'``, ``'cifar'`` and ``'svhn'`` in that order; otherwise it is
    called once for the named dataset.

    :param args: object exposing ``dataset``, ``batch_size`` and ``epochs``
        attributes (e.g. the namespace returned by ``argparse``).
    :return: None
    :raises AssertionError: if ``args.dataset`` is not one of the supported
        names.
    """
    # Raise explicitly instead of using an ``assert`` statement: asserts are
    # stripped under ``python -O``, which would let an invalid dataset name
    # through. AssertionError is kept so existing callers see the same type.
    if args.dataset not in ('mnist', 'cifar', 'svhn', 'all'):
        raise AssertionError(
            "dataset parameter must be either 'mnist', 'cifar', 'svhn' or all"
        )

    if args.dataset == 'all':
        for dataset in ['mnist', 'cifar', 'svhn']:
            train(dataset, args.batch_size, args.epochs)
    else:
        train(args.dataset, args.batch_size, args.epochs)
if __name__ == "__main__":
    # Command-line entry point: parse the dataset name, epoch count and
    # batch size, then hand the parsed namespace to main().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '-d', '--dataset',
        type=str,
        required=True,
        help="Dataset to use; either 'mnist', 'cifar', 'svhn' or 'all'"
    )
    # Optional flags; the defaults apply when a flag is omitted.
    cli.add_argument(
        '-e', '--epochs',
        type=int,
        required=False,
        default=120,
        help="The number of epochs to train for."
    )
    cli.add_argument(
        '-b', '--batch_size',
        type=int,
        required=False,
        default=100,
        help="The batch size to use for training."
    )
    main(cli.parse_args())