forked from kathrinse/TabSurvey
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
139 lines (99 loc) · 4.52 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import optuna
from models import str2model
from utils.load_data import load_data
from utils.scorer import get_scorer
from utils.timer import Timer
from utils.io_utils import save_results_to_file, save_hyperparameters_to_file, save_loss_to_file
from utils.parser import get_parser, get_given_parameters_parser
from sklearn.model_selection import KFold, StratifiedKFold # , train_test_split
def cross_validation(model, X, y, args, save_model=False):
    """Cross-validate ``model`` on ``(X, y)`` and collect scores and timings.

    For every fold a fresh clone of ``model`` is fitted, timed, used to
    predict the held-out fold, and asked to persist its weights and
    predictions through ``save_model_and_predictions``.

    Args:
        model: Unfitted model providing ``clone()`` and ``params``. Each clone
            must provide ``fit``, ``predict``, ``save_model_and_predictions``
            and the ``predictions`` / ``prediction_probabilities`` attributes.
        X: Feature data, indexed with the index arrays produced by the
            splitter (assumes a NumPy-style array -- TODO confirm).
        y: Targets, indexed the same way as ``X``.
        args: Parsed arguments; ``objective``, ``num_splits``, ``shuffle`` and
            ``seed`` are read here, and the object is forwarded to the scorer
            and the file-saving helpers.
        save_model: When true, additionally write the per-fold loss histories
            and the aggregated results to file.

    Returns:
        A tuple ``(scorer, (average_train_time, average_test_time))``.

    Raises:
        NotImplementedError: If ``args.objective`` is not ``"regression"``,
            ``"classification"`` or ``"binary"``.
    """
    # Record some statistics and metrics
    sc = get_scorer(args)
    train_timer = Timer()
    test_timer = Timer()

    # scikit-learn raises a ValueError when random_state is set while
    # shuffle is False, so the seed is only forwarded when shuffling.
    random_state = args.seed if args.shuffle else None

    if args.objective == "regression":
        kf = KFold(n_splits=args.num_splits, shuffle=args.shuffle,
                   random_state=random_state)
    elif args.objective in ("classification", "binary"):
        kf = StratifiedKFold(n_splits=args.num_splits, shuffle=args.shuffle,
                             random_state=random_state)
    else:
        # An f-string keeps the words separated and also works when the
        # objective is not a string (e.g. None).
        raise NotImplementedError(
            f"Objective {args.objective} is not yet implemented.")

    for i, (train_index, test_index) in enumerate(kf.split(X, y)):
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]

        # Create a new unfitted version of the model
        curr_model = model.clone()

        # Train model.
        # NOTE(review): the held-out fold is also what gets passed to fit()
        # as the second data pair; no separate validation split is made.
        train_timer.start()
        loss_history, val_loss_history = curr_model.fit(
            X_train, y_train, X_test, y_test)
        train_timer.end()

        # Test model; predict() is expected to populate
        # curr_model.predictions / curr_model.prediction_probabilities,
        # which are read below.
        test_timer.start()
        curr_model.predict(X_test)
        test_timer.end()

        # Save model weights and the truth/prediction pairs for traceability
        curr_model.save_model_and_predictions(y_test, i)

        if save_model:
            save_loss_to_file(args, loss_history, "loss", extension=i)
            save_loss_to_file(args, val_loss_history, "val_loss", extension=i)

        # Compute scores on the output
        sc.eval(y_test, curr_model.predictions,
                curr_model.prediction_probabilities)

        print(sc.get_results())

    avg_train_time = train_timer.get_average_time()
    avg_test_time = test_timer.get_average_time()

    # Best run is saved to file
    if save_model:
        print("Results:", sc.get_results())
        print("Train time:", avg_train_time)
        print("Inference time:", avg_test_time)

        # Save all the statistics to a file
        save_results_to_file(args, sc.get_results(),
                             avg_train_time, avg_test_time,
                             model.params)

    return sc, (avg_train_time, avg_test_time)
class Objective:
    """Callable objective handed to ``study.optimize``.

    Bundles the model class, the training data and the parsed arguments so
    that a call with an Optuna trial can build, cross-validate and score one
    hyperparameter configuration.
    """

    def __init__(self, args, model_name, X, y):
        """Store the model class, the data and the run arguments."""
        # Run configuration forwarded to the model and the helpers
        self.args = args

        # Model class to instantiate for every trial (despite the name, it is
        # called like a class in __call__)
        self.model_name = model_name

        # Training data used for cross validation
        self.X, self.y = X, y

    def __call__(self, trial):
        """Evaluate one trial and return the value to be optimized.

        Args:
            trial: The trial object used to sample the hyperparameters.

        Returns:
            The result of ``get_objective_result()`` of the scorer filled
            during cross validation.
        """
        model_cls = self.model_name

        # Let the model class sample its own hyperparameters for this trial
        hyperparams = model_cls.define_trial_parameters(trial, self.args)

        # Build the model and cross validate the sampled configuration
        candidate = model_cls(hyperparams, self.args)
        scorer, timings = cross_validation(candidate, self.X, self.y, self.args)

        # Log the configuration together with its scores and timings
        save_hyperparameters_to_file(
            self.args, hyperparams, scorer.get_results(), timings)

        return scorer.get_objective_result()
def main(args):
    """Run the hyperparameter search and re-evaluate the best trial.

    Loads the data, optimizes an ``Objective`` for ``args.n_trials`` trials in
    the direction given by ``args.direction``, prints the best parameters and
    finally cross-validates a model built from them with ``save_model=True``.

    Args:
        args: Parsed arguments; ``model_name``, ``direction`` and ``n_trials``
            are read here, and the object is forwarded to the helpers.
    """
    print("Start hyperparameter optimization")
    X, y = load_data(args)

    # Resolve the model class from its name
    model_cls = str2model(args.model_name)

    study = optuna.create_study(direction=args.direction)
    objective = Objective(args, model_cls, X, y)
    study.optimize(objective, n_trials=args.n_trials)

    best_params = study.best_trial.params
    print("Best parameters:", best_params)

    # Run best trial again and save it!
    best_model = model_cls(best_params, args)
    cross_validation(best_model, X, y, args, save_model=True)
def main_once(args):
    """Cross-validate a single model using the hyperparameters given in ``args``.

    The parameters are looked up as
    ``args.parameters[args.dataset][args.model_name]``. The resulting scores
    and the ``(train, test)`` average timings are printed.

    Args:
        args: Parsed arguments; ``model_name``, ``dataset`` and ``parameters``
            are read here, and the object is forwarded to the helpers.
    """
    print("Train model with given hyperparameters")
    X, y = load_data(args)

    # Resolve the model class, then pick its parameter set for this dataset
    model_cls = str2model(args.model_name)
    given_params = args.parameters[args.dataset][args.model_name]

    scorer, timings = cross_validation(
        model_cls(given_params, args), X, y, args)

    print(scorer.get_results())
    print(timings)
if __name__ == "__main__":
    # First pass over the command line with the general parser
    parser = get_parser()
    arguments = parser.parse_args()
    print(arguments)

    if not arguments.optimize_hyperparameters:
        # Single run: parse again with the parser that also loads the best
        # parameters (main_once reads them from ``arguments.parameters``)
        parser = get_given_parameters_parser()
        arguments = parser.parse_args()
        main_once(arguments)
    else:
        # Hyperparameter search followed by a saved run of the best trial
        main(arguments)