-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
executable file
·129 lines (110 loc) · 4.94 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#!/usr/bin/env python
from __future__ import print_function
import pickle
import json
import numpy as np
import glob
import os
import neukrill_net.utils as utils
import neukrill_net.image_processing as image_processing
import neukrill_net.augment as augment
import sklearn.preprocessing
import sklearn.ensemble
import sklearn.linear_model
import sklearn.cross_validation
import sklearn.dummy
from sklearn.externals import joblib
import sklearn.metrics
import sklearn.pipeline
import argparse
def main(run_settings_path, verbose=False, force=False):
    """Load global and per-run settings, then hand off to the right trainer.

    Args:
        run_settings_path: path to the run-settings JSON file for this run.
        verbose: forwarded to the selected training function.
        force: forwarded to ``utils.load_run_settings`` and to the selected
            training function.

    Raises:
        NotImplementedError: if ``run_settings['model type']`` is neither
            ``'sklearn'`` nor ``'pylearn2'``.
    """
    # global (non-run-specific) settings, read from the working directory
    global_settings = utils.Settings('settings.json')
    # per-run settings, layered on top of the global ones
    run_settings = utils.load_run_settings(
        run_settings_path,
        global_settings,
        settings_path='settings.json',
        force=force,
    )

    model_type = run_settings['model type']
    # ordered (name, trainer) pairs; compared with == exactly like an if/elif
    trainers = (
        ('sklearn', train_sklearn),
        ('pylearn2', train_pylearn2),
    )
    for known_type, trainer in trainers:
        if model_type == known_type:
            trainer(run_settings, verbose=verbose, force=force)
            break
    else:
        # no trainer matched the requested model type
        raise NotImplementedError("Unsupported model type.")
def train_sklearn(run_settings, verbose=False, force=False):
    """Cross-validate a scikit-learn classifier on the training images and save it.

    Loads every training image listed in the global settings and applies the
    processing described by ``run_settings["preprocessing"]``. Class names are
    label-encoded, and the configured classifier is scored with stratified
    shuffle-split cross-validation using log loss per fold. The classifier is
    then dumped with joblib to ``run_settings["pickle abspath"]``. Finally the
    per-fold log losses are written back to the run settings under
    ``"crossval results"`` and saved.

    Args:
        run_settings: run-settings mapping (as built by
            ``utils.load_run_settings``). Keys read here: ``'settings'``,
            ``'classifier'``, ``'preprocessing'``, ``'cross validation'`` and
            ``'pickle abspath'``.
        verbose: forwarded to ``utils.load_data``.
        force: currently unused; accepted so both trainers share a signature.

    Raises:
        NotImplementedError: if ``run_settings['classifier']`` is not one of
            ``'dummy'``, ``'logistic regression'`` or ``'random forest'``.
    """
    settings = run_settings['settings']

    # Resolve the classifier first so an unsupported name fails immediately,
    # before the slow image loading. Previously an unknown name left ``clf``
    # unbound and surfaced later as a NameError inside the CV loop.
    clf = _build_sklearn_classifier(run_settings['classifier'])

    # processing options come from the run-settings JSON
    processing = augment.augmentation_wrapper(**run_settings["preprocessing"])
    # load all training images as a design matrix, applying the processing
    X, y = utils.load_data(settings.image_fnames, classes=settings.classes,
                           processing=processing, verbose=verbose)
    # replace class-name labels with integer codes
    label_encoder = sklearn.preprocessing.LabelEncoder()
    y = label_encoder.fit_transform(y)

    # only stratified shuffle split is supported for now; its keyword
    # arguments are taken verbatim from the run settings
    cv = sklearn.cross_validation.StratifiedShuffleSplit(
        y, **run_settings['cross validation'])
    results = []
    for train_idx, test_idx in cv:
        clf.fit(X[train_idx], y[train_idx])
        probabilities = clf.predict_proba(X[test_idx])
        results.append(sklearn.metrics.log_loss(y[test_idx], probabilities))
    print("Average CV: {0} +/- {1}".format(np.mean(results), np.std(results)))

    # NOTE(review): ``clf`` is persisted as fitted on the LAST fold's training
    # split only; it is not refit on the full data set. Confirm this is intended.
    # The target path is computed elsewhere (presumably by utils).
    joblib.dump(clf, run_settings["pickle abspath"], compress=3)
    # store the raw per-fold log losses back in the run settings json
    run_settings["crossval results"] = results
    utils.save_run_settings(run_settings)


def _build_sklearn_classifier(classifier_name):
    """Return an unfitted estimator for the classifier named in the run settings.

    Raises:
        NotImplementedError: if ``classifier_name`` is not supported.
    """
    if classifier_name == 'dummy':
        # uniform-probability baseline, handy for exercising the pipeline
        return sklearn.dummy.DummyClassifier(strategy='uniform')
    if classifier_name == 'logistic regression':
        # loss='log' makes SGD fit a logistic model, which provides the
        # predict_proba needed for the log-loss scoring in train_sklearn
        return sklearn.linear_model.SGDClassifier(n_jobs=-1, loss='log')
    if classifier_name == 'random forest':
        forest = sklearn.ensemble.RandomForestClassifier(n_jobs=-1,
                                                         n_estimators=100,
                                                         max_depth=5)
        scaler = sklearn.preprocessing.StandardScaler()
        return sklearn.pipeline.Pipeline([("scl", scaler), ("clf", forest)])
    raise NotImplementedError(
        "Unsupported classifier: {0!r}".format(classifier_name))
def train_pylearn2(run_settings, verbose=False, force=False):
    """Train a pylearn2 model using the settings found in ``run_settings``.

    Renders the run's YAML via ``utils.format_yaml`` and saves the run
    settings JSON. It then builds the object described by that YAML with
    pylearn2's ``yaml_parse.load`` and runs its ``main_loop()``.

    Args:
        run_settings: run-settings mapping; must contain ``'settings'``.
        verbose: currently unused; accepted so both trainers share a signature.
        force: currently unused; accepted so both trainers share a signature.
    """
    # Import the submodule explicitly: ``import pylearn2.config`` alone does
    # not guarantee the ``yaml_parse`` attribute exists. The import is
    # function-local, presumably so the sklearn path runs without pylearn2.
    import pylearn2.config.yaml_parse

    settings = run_settings['settings']
    # fill in the YAML template for this run
    yaml_string = utils.format_yaml(run_settings, settings)
    # persist the run settings before starting the (possibly long) training
    utils.save_run_settings(run_settings)
    # NOTE(review): yaml_parse.load can instantiate arbitrary Python objects;
    # only feed it YAML from trusted run settings.
    train = pylearn2.config.yaml_parse.load(yaml_string)
    # run the model. The stray ``pdb.set_trace()`` that used to follow this
    # call was removed; it halted every run in an interactive debugger.
    train.main_loop()
if __name__ == '__main__':
    # Command-line entry point: parse the run-settings path and flags, then train.
    parser = argparse.ArgumentParser(
        # single literal: the old implicit concatenation lacked a space and
        # rendered as "store apickled" in --help
        description='Train a model and store a pickled model file.')
    # nargs='?' makes the positional optional, falling back to the default path
    parser.add_argument('run_settings', metavar='run_settings', type=str,
                        nargs='?',
                        default=os.path.join("run_settings", "default.json"),
                        help="Path to run settings json file.")
    # force option
    parser.add_argument('-f', action="store_true",
                        help="Force overwrite of model files/submission "
                             "csvs/anything else.")
    # verbose option
    parser.add_argument('-v', action="store_true", help="Run verbose.")
    args = parser.parse_args()
    main(args.run_settings, verbose=args.v, force=args.f)