-
Notifications
You must be signed in to change notification settings - Fork 2
/
config.py
61 lines (56 loc) · 2.07 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import tensorflow as tf
from models import graph_baselines
from models import unet
from models import graph_unet
import train_utils
# Default values for every CLI argument (used when a flag is not given).
DEFAULTS = {
    # --- model hyperparameters ---
    # Number of units in the first (and last) U-Net block.
    'units': 32,
    # Number of output neurons (12 * 8 = 96).
    'out_units': 12 * 8,
    # Number of downsampling / upsampling blocks.
    'depth': 3,
    # Layer type used in the Graph-UNet model (s. layers/__init__.py).
    'layer_type': 'geo_quadrant_gcn',
    'activation': 'relu',
    'use_bias': True,
    # If True, includes a global node in the graph layers (default).
    'use_global': True,
    # Activation applied to the model output; None means linear output.
    'output_activation': None,  # alternative: 'relu'
    # --- training parameters ---
    'data_dir': 'data/raw/',
    # Type of data: "image" for convolution-based models, "graph" for graph-based ones.
    'data_type': 'image',
    # Group name (used by the wandb logger to group runs).
    'group': 'Baselines',
    'batch': 1,
    'epochs': 15,
    # Either a schedule name (s. LEARNING_SCHEDULES below) or a fixed float such as 1e-3.
    'learning_rate': 'warmup+expDecay',
    'optimizer': 'adam',
    'loss': 'mse',
    # The data is quite large, so gradients are accumulated over this many steps.
    'acc_gradients_steps': 16,
    'add_temp_encoding': True,
    'add_street_encoding': False,
    'validation_fraction': 0.1,
    'seed_len': 12,
    'target_len': 12,
    'ckpts_dir': './ckpts',
}
# CLI argument --model
# Maps every accepted --model value to the class implementing it.
MODELS = dict(
    UNet=unet.VanillaUNet,
    GraphUNet=graph_unet.GraphUNet,
)
# CLI argument --loss
# Maps every accepted loss name to its Keras loss function (usable by all models).
LOSSES = dict(
    mse=tf.keras.losses.MSE,  # mean squared error
    mae=tf.keras.losses.MAE,  # mean absolute error
)
def _warmup_exp_decay_factory(model):
    """Build a zero-argument schedule for *model*.

    The returned callable evaluates train_utils.warmupExpDecay at the
    model's current global step each time it is invoked.
    """
    return lambda: train_utils.warmupExpDecay(model.global_step)


# Maps each learning-rate schedule name to a factory taking the model instance.
LEARNING_SCHEDULES = {
    'warmup+expDecay': _warmup_exp_decay_factory,
}