-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
117 lines (95 loc) · 3.85 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import mlflow
import os
import hydra
from omegaconf import DictConfig, OmegaConf
from omegaconf.listconfig import ListConfig
def _parse_execute_steps(execute_steps):
    """Normalize ``main.execute_steps`` into a list of step names.

    Args:
        execute_steps: Either a comma-separated string (the form used when the
            value is passed on the command line) or a list-like value coming
            from the config file (an OmegaConf ``ListConfig``, list or tuple).

    Returns:
        A list of step-name strings.

    Raises:
        TypeError: If ``execute_steps`` is neither a string nor list-like.
    """
    if isinstance(execute_steps, str):
        # Strip surrounding whitespace so "download, preprocess" selects both
        # steps, and drop empty entries produced by stray/trailing commas.
        return [step.strip() for step in execute_steps.split(",") if step.strip()]
    if isinstance(execute_steps, (ListConfig, list, tuple)):
        return list(execute_steps)
    # Raise rather than ``assert`` so the check is not stripped by ``python -O``.
    raise TypeError(
        "main.execute_steps must be a comma-separated string or a list, "
        f"got {type(execute_steps).__name__}"
    )


# This automatically reads in the configuration
@hydra.main(config_name='config')
def go(config: DictConfig):
    """Run the selected steps of the MLflow pipeline.

    Each step is an MLflow project located in a sub-directory (named after the
    step) of the original working directory, and is launched via its "main"
    entry point. Steps are selected with ``main.execute_steps`` and always run
    in this fixed order, regardless of the order in which they are listed:
    download, preprocess, check_data, segregate, random_forest, evaluate.

    Args:
        config: The Hydra/OmegaConf configuration. It reads the ``main``,
            ``data`` and ``random_forest_pipeline`` sections.

    Raises:
        TypeError: If ``main.execute_steps`` is neither a string nor list-like.
    """
    # Setup the wandb experiment. All runs will be grouped under this name
    os.environ["WANDB_PROJECT"] = config["main"]["project_name"]
    os.environ["WANDB_RUN_GROUP"] = config["main"]["experiment_name"]

    # Step projects live relative to the directory the script was launched
    # from. get_original_cwd() is used because Hydra may change the cwd.
    root_path = hydra.utils.get_original_cwd()

    # Check which steps we need to execute
    steps_to_execute = _parse_execute_steps(config["main"]["execute_steps"])

    def run_step(step, parameters):
        """Launch the "main" entry point of the MLflow project for ``step``."""
        mlflow.run(os.path.join(root_path, step), "main", parameters=parameters)

    if "download" in steps_to_execute:
        run_step("download", {
            "file_url": config["data"]["file_url"],
            "artifact_name": "raw_data.parquet",
            "artifact_type": "raw_data",
            "artifact_description": "Data as downloaded",
        })

    if "preprocess" in steps_to_execute:
        run_step("preprocess", {
            "input_artifact": "raw_data.parquet:latest",
            "artifact_name": "preprocessed_data.csv",
            "artifact_type": "preprocessed_data",
            "artifact_description": "Preprocessed Data",
        })

    if "check_data" in steps_to_execute:
        run_step("check_data", {
            "reference_artifact": config["data"]["reference_dataset"],
            "sample_artifact": "preprocessed_data.csv:latest",
            "ks_alpha": config["data"]["ks_alpha"],
        })

    if "segregate" in steps_to_execute:
        run_step("segregate", {
            "input_artifact": "preprocessed_data.csv:latest",
            "artifact_root": "data",
            "artifact_type": "Train-Test Data",
            "test_size": config["data"]["test_size"],
            "random_state": config["main"]["random_seed"],
            "stratify": config["data"]["stratify"],
        })

    if "random_forest" in steps_to_execute:
        # Serialize the random forest configuration to a YAML file so that the
        # step can read it; the absolute path is resolved against the current
        # working directory (which Hydra may have changed — see root_path).
        model_config = os.path.abspath("random_forest_config.yml")
        with open(model_config, "w", encoding="utf-8") as fp:
            fp.write(OmegaConf.to_yaml(config["random_forest_pipeline"]))

        run_step("random_forest", {
            "train_data": "data_train.csv:latest",
            "model_config": model_config,
            "export_artifact": config["random_forest_pipeline"]["export_artifact"],
            "random_seed": config["main"]["random_seed"],
            "val_size": config["data"]["val_size"],
            "stratify": config["data"]["stratify"],
        })

    if "evaluate" in steps_to_execute:
        run_step("evaluate", {
            "model_export": config["random_forest_pipeline"]["export_artifact"] + ":latest",
            "test_data": "data_test.csv:latest",
        })
if __name__ == "__main__":
    # The @hydra.main decorator builds ``config`` and passes it to ``go``, so
    # the entry point is called with no arguments here.
    go()