-
Notifications
You must be signed in to change notification settings - Fork 0
/
ubermain.jl
64 lines (51 loc) · 1.49 KB
/
ubermain.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
using Distributed
using ArgParse
"""
    parse_ubermain()

Parse this driver script's command-line options from `ARGS`.

Return the parsed options as a dictionary keyed by option name:
- `"procs"`: number of worker processes to add (default 1, must be ≥ 0).
- `"runs"`: number of runs per experiment setting (default 5, must be ≥ 1).

Out-of-range values are rejected at parse time by ArgParse's `range_tester`.
This gives a clear CLI error instead of a later failure in `addprocs` or an
empty task list.
"""
function parse_ubermain()
    settings = ArgParseSettings()
    @add_arg_table settings begin
        "--procs", "-p"
        help = "Number of parallel processes/workers to spawn."
        arg_type = Int
        default = 1
        range_tester = x -> x >= 0
        "--runs", "-r"
        help = "Number of runs per experiment setting."
        arg_type = Int
        default = 5
        range_tester = x -> x >= 1
    end
    return parse_args(settings)
end
# Parse the command-line options (number of worker processes, runs per setting).
# `ub_args` is a global; its "runs" entry is consumed by the final `ubermain` call.
ub_args = parse_ubermain()
# Spawn `ub_args["procs"]` local worker processes ("start workers in BPTT env").
# Each worker is started with the same thread count as this process and with
# the currently active project, so it resolves the same package environment.
addprocs(
ub_args["procs"];
exeflags = `--threads=$(Threads.nthreads()) --project=$(Base.active_project())`,
)
# Load BPTT on the master and on every worker. This must come after `addprocs`,
# because `@everywhere` only reaches processes that already exist.
@everywhere using BPTT
# NOTE(review): presumably selects GR's "nul" workstation type so that plotting
# runs headless (no display window) on every process — confirm.
@everywhere ENV["GKSwstype"] = "nul"
# Participant IDs, one per line of the listing file. The path is relative, so
# it is resolved against the current working directory. Surrounding whitespace
# is stripped and blank lines are skipped so neither becomes a bogus `data_id`.
ids = let path = "example_data/LEMON_data/list_all_participants.txt"
    [String(strip(line)) for line in readlines(path) if !isempty(strip(line))]
end

"""
    ubermain(n_runs::Int; data_ids = ids)

Start multiple parallel trainings, with optional grid search and
multiple runs per experiment.

The experiment settings below are combined with the BPTT defaults
(`argtable()`) and expanded into a task list by `prepare_tasks`, with `n_runs`
runs per setting. The tasks are executed with `pmap(main_routine, tasks)`
over the available worker processes.

# Arguments
- `n_runs::Int`: number of runs per experiment setting.

# Keywords
- `data_ids`: participant IDs passed as the `data_id` argument. Defaults to the
  IDs read from the LEMON participant list into the global `ids`.

# Returns
The results collected by `pmap(main_routine, tasks)`.

# Throws
- `ArgumentError` if `data_ids` is empty.
"""
function ubermain(n_runs::Int; data_ids = ids)
    isempty(data_ids) && throw(ArgumentError("no participant IDs given (data_ids is empty)"))
    # Load the defaults with the correct data types (empty CLI argument list).
    defaults = parse_args([], argtable())
    # Settings for this sweep. NOTE(review): the vector-valued entries
    # (`hidden_dim`, `data_id`) appear to be the grid-search axes expanded by
    # `prepare_tasks`, and the third field looks like a short tag — confirm in BPTT.
    args = BPTT.ArgVec([
        Argument("experiment", "LEMON"),
        Argument("name", "ubermain_test"),
        Argument("model", "clippedShallowPLRNN"),
        Argument("weak_tf_alpha", 0.1),
        Argument("hidden_dim", [10, 30], "H"),
        Argument("data_id", data_ids, "id"),
    ])
    tasks = prepare_tasks(defaults, args, n_runs)
    @info "Prepared $(length(tasks)) training tasks"
    # Distribute the tasks over the worker processes added via `addprocs`.
    return pmap(main_routine, tasks)
end

ubermain(ub_args["runs"])