parsing.jl
using ArgParse
using Flux
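
"""
    initialize_model(args, D; mod = @__MODULE__)

Construct the latent model named by `args["model"]`. The model type is resolved
in module `mod`, and its constructor arguments are chosen based on the model's
supertype (vanilla, dendritic, deep, or shallow PLRNN). If a path to external
inputs is given, the input dimension is passed as an additional argument.
"""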
function initialize_model(args::AbstractDict, D::AbstractDataset; mod = @__MODULE__)
    # gather args
    M = args["latent_dim"]
    id_tf = args["observation_model"] == "Identity"
    B = args["num_bases"]
    Layers = args["hidden_layers"]
    model_name = args["model"]
    hidden_dim = args["hidden_dim"]
    # resolve model type in the correct module scope
    model_t = @eval mod $(Symbol(model_name))
    # specify model args based on model type
    if model_t <: AbstractVanillaPLRNN
        model_args = (M,)
    elseif model_t <: AbstractDendriticPLRNN
        model_args = id_tf ? (M, B, D.X) : (M, B)
    elseif model_t <: AbstractDeepPLRNN
        model_args = (M, Layers)
    elseif model_t <: AbstractShallowPLRNN
        model_args = (M, hidden_dim)
    else
        error("Unsupported model type: $model_name")
    end
    # external inputs?
    K = !isempty(args["path_to_inputs"]) ? (size(D.S, 2),) : ()
    # optional arguments
    opt_args = args["optional_model_args"]
    # initialize model
    model = model_t(model_args..., opt_args..., K...)
    println("Model / # Parameters: $(typeof(model)) / $(num_params(model))")
    return model
end
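
"""
    initialize_observation_model(args, D)

Construct the observation model specified by `args["observation_model"]`.
Datasets without nuisance artifacts use an `Affine` (without bias) or `Identity`
observation model; datasets with nuisance artifacts require the `Regressor`
observation model.
"""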
function initialize_observation_model(args::AbstractDict, D::AbstractDataset)
    N = size(D.X, 2)
    M = args["latent_dim"]
    if !isdefined(D, :R) # dataset has no nuisance artifacts (which would be relevant for the observation model)
        # initialize by default w/o bias
        if args["observation_model"] == "Affine"
            obs_model = Affine(N, M; use_bias = false)
        elseif args["observation_model"] == "Identity"
            obs_model = Identity(N, M)
        elseif args["observation_model"] == "Regressor"
            error("No nuisance artifacts given: Regressor observation model cannot be used. Choose a different observation model.")
        else
            error("Unknown observation model: $(args["observation_model"])")
        end
    else
        P = size(D.R, 2) # dimension of the nuisance artifacts
        if args["observation_model"] == "Regressor"
            obs_model = Regressor(N, M, P)
        else
            error("Nuisance artifacts given: This observation model does not take them into account. Choose a different observation model.")
        end
    end
    println("Obs. Model / # Parameters: $(typeof(obs_model)) / $(num_params(obs_model))")
    return obs_model
end
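
"""
    initialize_optimizer(args)

Build the Flux optimizer chain: optional gradient clipping by norm, the SGD
optimizer named in `args["optimizer"]` initialized with the start learning rate,
and an exponential learning rate decay from `start_lr` to `end_lr` over the
course of training.
"""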
function initialize_optimizer(args::Dict{String, Any})
    # optimizer chain
    opt_vec = []
    # vars
    κ = args["gradient_clipping_norm"]::Float32
    ηₛ = args["start_lr"]::Float32
    ηₑ = args["end_lr"]::Float32
    E = args["epochs"]::Int
    bpe = args["batches_per_epoch"]::Int
    # set gradient clipping
    if κ > zero(κ)
        push!(opt_vec, ClipNorm(κ))
    end
    # set SGD optimizer (ADAM, RAdam, etc.)
    opt_sym = Symbol(args["optimizer"])
    opt = @eval $opt_sym($ηₛ)
    push!(opt_vec, opt)
    # set exponential decay learning rate scheduler
    γ = exp(log(ηₑ / ηₛ) / E)
    decay = ExpDecay(1, γ, bpe, ηₑ, 1)
    push!(opt_vec, decay)
    return Flux.Optimise.Optimiser(opt_vec...)
end
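
"""
    get_device(args)

Return the Flux device transfer function (`gpu` or `cpu`) selected by
`args["device"]`.
"""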
get_device(args::AbstractDict) = args["device"] == "gpu" ? gpu : cpu
"""
argtable()
Prepare the argument table holding the information of all possible arguments
and correct datatypes.
"""
function argtable()
    settings = ArgParseSettings()
    defaults = load_defaults()
    @add_arg_table settings begin
        # meta
        "--experiment"
            help = "The overall experiment name."
            arg_type = String
            default = defaults["experiment"] |> String
        "--name"
            help = "Name of a single experiment instance."
            arg_type = String
            default = defaults["name"] |> String
        "--run", "-r"
            help = "The run ID."
            arg_type = Int
            default = defaults["run"] |> Int
        "--scalar_saving_interval"
            help = "The interval (in epochs) at which scalar quantities are stored."
            arg_type = Int
            default = defaults["scalar_saving_interval"] |> Int
        "--image_saving_interval"
            help = "The interval (in epochs) at which images are stored."
            arg_type = Int
            default = defaults["image_saving_interval"] |> Int
        # data
        "--path_to_data", "-d"
            help = "Path to the dataset (or dataset folder) used for training."
            arg_type = String
            default = defaults["path_to_data"] |> String
        "--path_to_inputs"
            help = "Path to the external inputs (or their folder) used for training."
            arg_type = String
            default = defaults["path_to_inputs"] |> String
        "--path_to_artifacts"
            help = "Path to the noise artifacts (or their folder) used for training."
            arg_type = String
            default = defaults["path_to_artifacts"] |> String
        "--data_id"
            help = "ID identifying the data and its inputs/artifacts.
                    Can be left empty if full file paths (rather than folders) are given above."
            arg_type = String
            default = defaults["data_id"] |> String
        "--train_test_split"
            help = "End of the training set, given either as an integer index into the dataset
                    or as a fraction of the total dataset length. If it equals the dataset length
                    (or 1.0), no test set is used."
            arg_type = Float32
            default = defaults["train_test_split"] |> Float32
        # training
        "--teacher_forcing_interval"
            help = "The teacher forcing interval to use."
            arg_type = Int
            default = defaults["teacher_forcing_interval"] |> Int
        "--weak_tf_alpha"
            help = "α used for weak teacher forcing."
            arg_type = Float32
            default = defaults["weak_tf_alpha"] |> Float32
        "--gaussian_noise_level"
            help = "Noise level of Gaussian noise added to teacher signals."
            arg_type = Float32
            default = defaults["gaussian_noise_level"] |> Float32
        "--sequence_length", "-T"
            help = "Length of sequences sampled from the dataset during training."
            arg_type = Int
            default = defaults["sequence_length"] |> Int
        "--batch_size", "-S"
            help = "The number of sequences to pack into one batch."
            arg_type = Int
            default = defaults["batch_size"] |> Int
        "--epochs", "-e"
            help = "The number of epochs to train for."
            arg_type = Int
            default = defaults["epochs"] |> Int
        "--batches_per_epoch"
            help = "The number of batches processed in each epoch."
            arg_type = Int
            default = defaults["batches_per_epoch"] |> Int
        "--gradient_clipping_norm"
            help = "The norm at which to clip gradients during training."
            arg_type = Float32
            default = defaults["gradient_clipping_norm"] |> Float32
        "--optimizer"
            help = "The optimizer to use for SGD optimization. Must be one provided by Flux.jl."
            arg_type = String
            default = defaults["optimizer"] |> String
        "--start_lr"
            help = "Learning rate passed to the optimizer at the beginning of training."
            arg_type = Float32
            default = defaults["start_lr"] |> Float32
        "--end_lr"
            help = "Target learning rate at the end of training due to exponential decay."
            arg_type = Float32
            default = defaults["end_lr"] |> Float32
        "--device"
            help = "Training device to use."
            arg_type = String
            default = defaults["device"] |> String
        # model
        "--model", "-m"
            help = "RNN to use."
            arg_type = String
            default = defaults["model"] |> String
        "--hidden_layers"
            help = "RNN MLP hidden layer dimensions."
            arg_type = String
            default = defaults["hidden_layers"] |> String
        "--latent_dim", "-M"
            help = "RNN latent dimension."
            arg_type = Int
            default = defaults["latent_dim"] |> Int
        "--num_bases", "-B"
            help = "Number of bases to use in the dendritic PLRNN."
            arg_type = Int
            default = defaults["num_bases"] |> Int
        "--observation_model", "-o"
            help = "Observation model to use."
            arg_type = String
            default = defaults["observation_model"] |> String
        # Manifold Attractor Regularization
        "--MAR_ratio"
            help = "Ratio of regularized states."
            arg_type = Float32
            default = defaults["MAR_ratio"] |> Float32
        "--MAR_lambda"
            help = "Regularization factor λ."
            arg_type = Float32
            default = defaults["MAR_lambda"] |> Float32
        "--lat_model_regularization"
            help = "Regularization λ for latent model parameters."
            arg_type = Float32
            default = defaults["lat_model_regularization"] |> Float32
        "--obs_model_regularization"
            help = "Regularization λ for observation model parameters."
            arg_type = Float32
            default = defaults["obs_model_regularization"] |> Float32
        # Metrics
        "--D_stsp_scaling"
            help = "GMM scaling parameter."
            arg_type = Float32
            default = defaults["D_stsp_scaling"] |> Float32
        "--D_stsp_bins"
            help = "Number of bins for the D_stsp binning method."
            arg_type = Int
            default = defaults["D_stsp_bins"] |> Int
        "--PSE_smoothing"
            help = "Gaussian kernel σ used for power spectrum smoothing."
            arg_type = Float32
            default = defaults["PSE_smoothing"] |> Float32
        "--PE_n"
            help = "Number of steps n for the n-step ahead prediction error."
            arg_type = Int
            default = defaults["PE_n"] |> Int
        "--hidden_dim"
            help = "Hidden dimension of the shallow PLRNN."
            arg_type = Int
            default = defaults["hidden_dim"] |> Int
        "--optional_model_args"
            help = "Optional model arguments."
            arg_type = Vector{String}
            default = defaults["optional_model_args"] |> Vector{String}
        # fMRI specifics
        "--TR"
            help = "Set to 0 to ignore. The time resolution of the fMRI timeseries,
                    i.e. the time between two images (in seconds)."
            arg_type = Float32
            default = defaults["TR"] |> Float32
        "--cut_l"
            help = "Set to 0 to ignore. How much is cropped from the beginning of the deconvolved
                    data, measured in lengths of the HRF. A float between 0 and 1 is interpreted
                    as a ratio; a value greater than 1 is interpreted as an integer."
            arg_type = Float32
            default = defaults["cut_l"] |> Float32
        "--cut_r"
            help = "Set to 0 to ignore. How much is cropped from the end of the deconvolved
                    data, measured in lengths of the HRF. A float between 0 and 1 is interpreted
                    as a ratio; a value greater than 1 is interpreted as an integer."
            arg_type = Float32
            default = defaults["cut_r"] |> Float32
        "--min_conv_noise"
            help = "If the noise level estimated for the Wiener deconvolution is smaller than
                    min_conv_noise, min_conv_noise is used as the noise level instead;
                    choosing 0 thus has no influence."
            arg_type = Float32
            default = defaults["min_conv_noise"] |> Float32
    end
    return settings
end
"""
parse_commandline()
Parses all commandline arguments for execution of `main.jl`.
"""
parse_commandline() = parse_args(argtable())
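
# Example usage (sketch): the parsing pipeline above is typically driven from
# `main.jl` roughly as follows. `load_dataset` is not defined in this file and
# is only assumed here for illustration.
#
#   args = parse_commandline()
#   D = load_dataset(args)                  # hypothetical data loader
#   model = initialize_model(args, D)
#   obs_model = initialize_observation_model(args, D)
#   opt = initialize_optimizer(args)
#   device = get_device(args)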