# -*- coding: utf-8 -*-
"""
Created on Wed Dec 11 16:13:28 2019
@author: Arthur
"""
import os
import numpy as np
import mlflow
import tempfile
from torch.utils.data import DataLoader, Subset
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
import torch
import torch.nn
import torch.nn.functional as F
# These imports are used to create the training datasets
from data.datasets import (DatasetWithTransform, DatasetTransformer,
RawDataFromXrDataset, ConcatDataset_,
Subset_, ComposeTransforms, MultipleTimeIndices)
# Some utils functions
from train.utils import (DEVICE_TYPE, learning_rates_from_string,
run_ids_from_string, list_from_string)
from data.utils import load_training_datasets, load_data_from_run
from testing.utils import create_test_dataset
from testing.metrics import MSEMetric, MaxMetric
from train.base import Trainer
import train.losses
import models.transforms
import argparse
import importlib
import pickle
from data.xrtransforms import SeasonalStdizer
import models.submodels
import sys
import copy
from utils import TaskInfo
from dask.diagnostics import ProgressBar
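# CLI helper functions: negative_int lets past time indices be passed as
# positive integers on the command line, and check_str_is_None maps the
# literal string 'None' (any case) to the Python value None.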
def negative_int(value: str):
return -int(value)
def check_str_is_None(s: str):
return None if s.lower() == 'none' else s
# PARAMETERS ---------
description = ('Trains a model on a chosen dataset from the store. Allows '
               'training parameters to be set via the CLI.')
parser = argparse.ArgumentParser(description=description)
parser.add_argument('exp_id', type=int,
help='Experiment id of the source dataset containing the '\
'training data.')
parser.add_argument('run_id', type=str,
help='Run id of the source dataset')
parser.add_argument('--batchsize', type=int, default=8)
parser.add_argument('--n_epochs', type=int, default=100)
parser.add_argument('--learning_rate', type=learning_rates_from_string,
                    default={0: 1e-3})
parser.add_argument('--train_split', type=float, default=0.8,
help='Between 0 and 1')
parser.add_argument('--test_split', type=float, default=0.8,
help='Between 0 and 1, greater than train_split.')
parser.add_argument('--time_indices', type=negative_int, nargs='*')
parser.add_argument('--printevery', type=int, default=20)
parser.add_argument('--weight_decay', type=float, default=0.05,
help="Depreciated. Controls the weight decay on the linear "
"layer")
parser.add_argument('--model_module_name', type=str, default='models.models1',
help='Name of the module containing the nn model')
parser.add_argument('--model_cls_name', type=str, default='FullyCNN',
help='Name of the class defining the nn model')
parser.add_argument('--loss_cls_name', type=str,
default='HeteroskedasticGaussianLossV2',
help='Name of the loss function used for training.')
parser.add_argument('--transformation_cls_name', type=str,
default='SquareTransform',
help='Name of the transformation applied to outputs ' \
'required to be positive. Should be defined in ' \
'models.transforms.')
parser.add_argument('--submodel', type=str, default='transform1')
parser.add_argument('--features_transform_cls_name', type=str, default='None',
                    help='Deprecated')
parser.add_argument('--targets_transform_cls_name', type=str, default='None',
                    help='Deprecated')
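# Example invocation (illustrative values only; the exact --learning_rate
# string format is defined by learning_rates_from_string):
#   python trainScript.py 2 <run_id> --batchsize 8 --n_epochs 100 \
#       --time_indices 0 --submodel transform1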
params = parser.parse_args()
# Log the experiment_id and run_id of the source dataset
mlflow.log_param('source.experiment_id', params.exp_id)
mlflow.log_param('source.run_id', params.run_id)
# Training parameters
# Note that we use two separate indices for the train/test split: this prevents
# temporal correlation between neighbouring samples from inflating test scores.
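# Samples with time index below train_split * len(dataset) are used for
# training, and only samples above test_split * len(dataset) for testing,
# leaving a temporal buffer in between.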
batch_size = params.batchsize
learning_rates = params.learning_rate
weight_decay = params.weight_decay
n_epochs = params.n_epochs
train_split = params.train_split
test_split = params.test_split
model_module_name = params.model_module_name
model_cls_name = params.model_cls_name
loss_cls_name = params.loss_cls_name
transformation_cls_name = params.transformation_cls_name
# Transforms applied to the features and targets
temp = params.features_transform_cls_name
features_transform_cls_name = check_str_is_None(temp)
temp = params.targets_transform_cls_name
targets_transform_cls_name = check_str_is_None(temp)
# Submodel (for instance monthly means)
submodel = params.submodel
# Parameters specific to the input data
# time_indices specifies the indices from the past that are used for prediction
indices = params.time_indices
# Other parameters
print_loss_every = params.printevery
model_name = 'trained_model.pth'
# Directories where temporary data will be saved
data_location = tempfile.mkdtemp(dir='/scratch/ag7531/temp/')
print('Created temporary dir at ', data_location)
figures_directory = 'figures'
models_directory = 'models'
model_output_dir = 'model_output'
def _check_dir(dir_path):
"""Create the directory if it does not already exists"""
if not os.path.exists(dir_path):
os.mkdir(dir_path)
_check_dir(os.path.join(data_location, figures_directory))
_check_dir(os.path.join(data_location, models_directory))
_check_dir(os.path.join(data_location, model_output_dir))
# Device selection. If available we use the GPU.
# TODO Allow CLI argument to select the GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device_type = DEVICE_TYPE.GPU if torch.cuda.is_available() \
else DEVICE_TYPE.CPU
print('Selected device type: ', device_type.value)
# END PARAMETERS ---------------------------------------------------------------
# DATA-------------------------------------------------------------------------
# Load the source dataset corresponding to the run id passed via the CLI
global_ds = load_data_from_run(params.run_id)
# Select the training subdomains defined in training_subdomains.yaml
xr_datasets = load_training_datasets(global_ds, 'training_subdomains.yaml')
# Split into train and test datasets
datasets, train_datasets, test_datasets = list(), list(), list()
for xr_dataset in xr_datasets:
# TODO this is a temporary fix to implement seasonal patterns
submodel_transform = copy.deepcopy(getattr(models.submodels, submodel))
print(submodel_transform)
xr_dataset = submodel_transform.fit_transform(xr_dataset)
with ProgressBar(), TaskInfo('Computing dataset'):
xr_dataset = xr_dataset.compute()
print(xr_dataset)
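    # Wrap the xarray dataset as a torch-style dataset indexed along time.
    # Inputs are the surface velocity components, outputs the corresponding
    # subgrid forcing components.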
dataset = RawDataFromXrDataset(xr_dataset)
dataset.index = 'time'
dataset.add_input('usurf')
dataset.add_input('vsurf')
dataset.add_output('S_x')
dataset.add_output('S_y')
# TODO temporary addition, should be made more general
if submodel == 'transform2':
dataset.add_output('S_x_d')
dataset.add_output('S_y_d')
train_index = int(train_split * len(dataset))
test_index = int(test_split * len(dataset))
features_transform = ComposeTransforms()
targets_transform = ComposeTransforms()
transform = DatasetTransformer(features_transform, targets_transform)
dataset = DatasetWithTransform(dataset, transform)
# dataset = MultipleTimeIndices(dataset)
# dataset.time_indices = [0, ]
train_dataset = Subset_(dataset, np.arange(train_index))
test_dataset = Subset_(dataset, np.arange(test_index, len(dataset)))
train_datasets.append(train_dataset)
test_datasets.append(test_dataset)
datasets.append(dataset)
# Concatenate the per-region datasets. Concatenation adds shape transforms to
# ensure that all regions produce fields of the same shape, hence it should be
# done after saving the per-region transformations, so that these shape
# transforms do not interfere when we later test on another region.
train_dataset = ConcatDataset_(train_datasets)
test_dataset = ConcatDataset_(test_datasets)
# Dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=batch_size,
shuffle=True, drop_last=True, num_workers=4)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size,
shuffle=False, drop_last=True)
print('Size of training data: {}'.format(len(train_dataset)))
print('Size of validation data: {}'.format(len(test_dataset)))
# END DATA----------------------------------------------------------------------
# NEURAL NETWORK---------------------------------------------------------------
# Instantiate the loss class specified in the script parameters
n_target_channels = datasets[0].n_targets
criterion = getattr(train.losses, loss_cls_name)(n_target_channels)
# Recover the model's class, based on the corresponding CLI parameters
try:
models_module = importlib.import_module(model_module_name)
model_cls = getattr(models_module, model_cls_name)
except ModuleNotFoundError as e:
    raise type(e)('Could not find the specified model module: ' + str(e))
except AttributeError as e:
raise type(e)('Could not find the specified model class: ' +
str(e))
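# The network outputs criterion.n_required_channels channels, which may exceed
# the number of targets (e.g. the heteroskedastic Gaussian loss expects both a
# mean and a precision channel per target).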
net = model_cls(datasets[0].n_features, criterion.n_required_channels)
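# The final transformation constrains the channels listed in
# criterion.precision_indices (the predicted precisions) to be positive.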
try:
transformation_cls = getattr(models.transforms, transformation_cls_name)
transformation = transformation_cls()
transformation.indices = criterion.precision_indices
net.final_transformation = transformation
except AttributeError as e:
raise type(e)('Could not find the specified transformation class: ' +
str(e))
print('--------------------')
print(net)
print('--------------------')
print('***')
# Log the text representation of the net into a txt artifact
with open(os.path.join(data_location, models_directory,
'nn_architecture.txt'), 'w') as f:
print('Writing neural net architecture into txt file.')
f.write(str(net))
# END NEURAL NETWORK -----------------------------------------------------------
# Add transforms required by the model.
for dataset in datasets:
dataset.add_transforms_from_model(net)
# Training---------------------------------------------------------------------
# Move the network to the selected device
net.to(device)
# Optimizer and learning rate scheduler
net_params = list(net.parameters())
optimizer = optim.Adam(net_params, lr=learning_rates[0],
                       weight_decay=weight_decay)
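# learning_rates maps epochs to learning rates: the value at key 0 is the
# initial rate, and the remaining keys are milestones at which MultiStepLR
# multiplies the current rate by gamma.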
lr_scheduler = MultiStepLR(optimizer, list(learning_rates.keys())[1:],
gamma=0.1)
trainer = Trainer(net, device)
trainer.criterion = criterion
trainer.print_loss_every = print_loss_every
# metrics saved independently of the training criterion.
metrics = {'R2': MSEMetric(), 'Inf Norm': MaxMetric()}
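# Each metric is evaluated after mapping values back through the targets'
# inverse transform, i.e. on untransformed target values.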
for metric_name, metric in metrics.items():
metric.inv_transform = lambda x: test_dataset.inverse_transform_target(x)
trainer.register_metric(metric_name, metric)
for i_epoch in range(n_epochs):
print('Epoch number {}.'.format(i_epoch))
# TODO remove clipping?
train_loss = trainer.train_for_one_epoch(train_dataloader, optimizer,
lr_scheduler, clip=1.)
test = trainer.test(test_dataloader)
if test == 'EARLY_STOPPING':
print(test)
break
test_loss, metrics_results = test
    # Print and log the training and test losses
print('Train loss for this epoch is ', train_loss)
print('Test loss for this epoch is ', test_loss)
for metric_name, metric_value in metrics_results.items():
print('Test {} for this epoch is {}'.format(metric_name, metric_value))
mlflow.log_metric('train loss', train_loss, i_epoch)
mlflow.log_metric('test loss', test_loss, i_epoch)
    mlflow.log_metrics(metrics_results, step=i_epoch)
# Update the logged number of actual training epochs
mlflow.log_param('n_epochs_actual', i_epoch + 1)
# END TRAINING -----------------------------------------------------------------
# Save the trained model to disk
net.cpu()
full_path = os.path.join(data_location, models_directory, model_name)
torch.save(net.state_dict(), full_path)
net.to(device)  # move back to the training device (works on CPU as well as GPU)
# Save other parts of the model
# TODO this should not be necessary
print('Saving other parts of the model')
full_path = os.path.join(data_location, models_directory, 'transformation')
with open(full_path, 'wb') as f:
pickle.dump(transformation, f)
with TaskInfo('Saving trained model'):
mlflow.log_artifact(os.path.join(data_location, models_directory))
# BEGIN TEST -------------------------------------------------------------------
for i_dataset, (dataset, test_dataset, xr_dataset) in enumerate(
        zip(datasets, test_datasets, xr_datasets)):
test_dataloader = DataLoader(test_dataset, batch_size=batch_size,
shuffle=False, drop_last=True)
output_dataset = create_test_dataset(net, criterion.n_required_channels,
xr_dataset, test_dataset,
test_dataloader, test_index, device)
# Save model output on the test dataset
output_dataset.to_zarr(os.path.join(data_location, model_output_dir,
f'test_output{i_dataset}'))
# Log artifacts
print('Logging artifacts...')
mlflow.log_artifact(os.path.join(data_location, figures_directory))
mlflow.log_artifact(os.path.join(data_location, model_output_dir))
print('Done...')