-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathinput.py
46 lines (37 loc) · 1.07 KB
/
input.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from params import fp_params, data_params, procedure_params, logging_config
import sys
from src import (
lc,
multiprocessing_logging,
log_inputs,
construct_fingerprint,
prepare_data,
run_single_model,
run_multiprocessing,
single_model_mp,
)
# Partition number `p` (for the given train/test indices), read from the first
# command-line argument; it is forwarded to run_single_model in the main section.
# NOTE(review): this raises IndexError when no argument is supplied and
# ValueError when it is not an integer, even for runs that never use `p`
# (e.g. the multiprocessing path) - confirm that is intended.
p = int(sys.argv[1])
# Configure the logger from the project's logging_config mapping.
lc.dictConfig(logging_config)
# Install the multiprocessing log handler (presumably so records emitted by
# worker processes reach the handlers configured above - verify against the
# multiprocessing_logging documentation).
multiprocessing_logging.install_mp_handler()
# Log the run's inputs.
# NOTE(review): these setup statements sit outside the `__main__` guard, so
# they run on every import of this module, not only when it is executed as a
# script - confirm that is intended.
log_inputs()
if __name__ == "__main__":
    # Optional step: construct the fingerprint from fp_params.
    if procedure_params["construct_fp"]:
        construct_fingerprint(**fp_params)

    # Machine learning: every one of these stages needs the prepared data.
    ml_stages = ("train", "predict", "hyperparametrize")
    if any(procedure_params[stage] for stage in ml_stages):
        X, y = prepare_data(**data_params)
        if procedure_params["multiprocessing"]:
            # Several models, distributed across processors.
            run_multiprocessing(single_model=single_model_mp, X=X, y=y)
        else:
            # One model, on the partition selected by `p`.
            run_single_model(X=X, y=y, p=p)