-
Notifications
You must be signed in to change notification settings - Fork 0
/
exe_pemsbay.py
132 lines (100 loc) · 3.91 KB
/
exe_pemsbay.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import argparse
import torch
import datetime
import json
import yaml
import os
from dataset_pemsbay import get_dataloader_original, get_dataloader
from vae_model_pems import VAE_pems
from utils_vae_pems import train_vae, evaluate_vae
from main_model import CSDI_pems
from utils import train, evaluate
def _build_parser():
    """Create the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: parser with every option registered.
    """
    cli = argparse.ArgumentParser(description="CSDI")

    # Config file name; it is looked up under the config/ directory.
    cli.add_argument("--config", type=str, default="base.yaml")
    cli.add_argument('--device', default='cuda:0', help='Device for Attack')

    # Run folder under ./save/ to restore weights from; when left empty the
    # script trains instead of loading.
    # Example value: pm25_validationindex0_20240215_214905
    cli.add_argument("--modelfolder", type=str, default="")
    cli.add_argument(
        "--validationindex",
        type=int,
        default=0,
        help="index of month used for validation (value:[0-7])",
    )
    cli.add_argument("--nsample", type=int, default=100)
    cli.add_argument("--unconditional", action="store_true", default=False)

    # Options forwarded to the data loaders.
    cli.add_argument(
        "--targetstrategy",
        type=str,
        default="random",
        choices=["mix", "random", "block"],
    )
    cli.add_argument("--missing_pattern", type=str, default="point")  # block|point
    return cli


parser = _build_parser()
args = parser.parse_args()
print(args)
# Read the experiment settings from the YAML file under config/.
path = f"config/{args.config}"
with open(path, "r") as f:
    config = yaml.safe_load(f)

# The command-line flags overwrite the corresponding model settings.
config["model"].update(
    is_unconditional=args.unconditional,
    target_strategy=args.targetstrategy,
)

# Echo the effective configuration.
print(json.dumps(config, indent=4))
# Create a timestamped output folder for this run and store the effective
# configuration inside it.
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
foldername = f"./save/pems_validationindex{args.validationindex}_{current_time}/"
print('model folder:', foldername)

os.makedirs(foldername, exist_ok=True)
with open(foldername + "config.json", "w") as f:
    json.dump(config, f, indent=4)
# Data loaders and scalers for the VAE stage; the batch size comes from the
# "train_VAE" section of the configuration.
(
    train_loader,
    valid_loader,
    test_loader,
    test_train_loader,
    test_valid_loader,
    scaler,
    mean_scaler,
) = get_dataloader_original(
    config["train_VAE"]["batch_size"],
    device=args.device,
    missing_pattern=args.missing_pattern,
    is_interpolate=config["model"]["use_guide"],
    target_strategy=args.targetstrategy,
)

# Build the VAE and move it to the selected device.
model_vae = VAE_pems(config, args.device).to(args.device)
if __name__ == '__main__':
    # Script entry point: run the VAE stage, then the diffusion stage.
    #
    # NOTE(review): the nesting below is reconstructed. The evaluate calls are
    # assumed to run after BOTH the train and the load branch, and the whole
    # diffusion stage is assumed to live inside this guard -- confirm against
    # the intended control flow.
    # NOTE(review): both stages restore weights from the same
    # "./save/<modelfolder>/model.pth" path -- confirm that a single folder is
    # really meant to serve the VAE and the diffusion model.

    # ---------------- Stage 1: VAE ----------------
    if args.modelfolder == "":
        # No checkpoint folder given: train from scratch.
        train_vae(
            model_vae,
            config["train_VAE"],
            train_loader,
            valid_loader=valid_loader,
            foldername=foldername,
        )
    else:
        # map_location remaps the stored tensors onto the device chosen with
        # --device, so a checkpoint written on another device still loads.
        model_vae.load_state_dict(
            torch.load(
                "./save/" + args.modelfolder + "/model.pth",
                map_location=args.device,
            )
        )

    evaluate_vae(
        model_vae,
        test_train_loader,
        test_valid_loader,
        test_loader,
        scaler=scaler,
        mean_scaler=mean_scaler,
        foldername=foldername,
    )

    # ---------------- Stage 2: diffusion ----------------
    print('######################## begin diffussioh ######################################')

    # Use a fresh timestamped folder for the second stage and store the
    # configuration there as well.
    current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    foldername = (
        "./save/pems_validationindex" + str(args.validationindex) + "_" + current_time + "/"
    )
    print('model folder:', foldername)
    os.makedirs(foldername, exist_ok=True)
    with open(foldername + "config.json", "w") as f:
        json.dump(config, f, indent=4)

    # Rebuild the loaders with the batch size from the "train_diffussion"
    # section; these replace the loaders and scalers of the VAE stage.
    train_loader, valid_loader, test_loader, scaler, mean_scaler = get_dataloader(
        config["train_diffussion"]["batch_size"],
        device=args.device,
        missing_pattern=args.missing_pattern,
        is_interpolate=config["model"]["use_guide"],
        target_strategy=args.targetstrategy,
    )

    model = CSDI_pems(config, args.device).to(args.device)

    if args.modelfolder == "":
        train(
            model,
            config["train_diffussion"],
            train_loader,
            valid_loader=valid_loader,
            foldername=foldername,
        )
    else:
        # Same device remapping as for the VAE checkpoint.
        model.load_state_dict(
            torch.load(
                "./save/" + args.modelfolder + "/model.pth",
                map_location=args.device,
            )
        )

    evaluate(
        model,
        test_loader,
        nsample=args.nsample,
        scaler=scaler,
        mean_scaler=mean_scaler,
        foldername=foldername,
    )