-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain_ssed.py
68 lines (41 loc) · 2.2 KB
/
train_ssed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import argparse
import torch
import yaml
import os
from Data_Preparation.data_prepare_ssed import prepare_data
from DDPM import DDPM
from denoising_model_seed import DualBranchDenoisingModel
from utils import train
from torch.utils.data import DataLoader, Subset, TensorDataset
from sklearn.model_selection import train_test_split
# Seed shared by both train_test_split calls so the train/val/test split is reproducible.
SPLIT_SEED = 666


def main():
    """Train the DDPM-wrapped DualBranchDenoisingModel on the SSED data.

    Steps:
      * Parse ``--config`` (a YAML file name under ``config/``), ``--device`` and
        ``--n_type`` from the command line.
      * Create ``./check_points/noise_type_<n_type>/`` and pass it to ``train``
        as ``foldername``.
      * Load the arrays returned by ``prepare_data`` for
        ``data/ssed_noise.npy`` / ``data/ssed_eeg.npy``.
      * Split the samples ~80/10/10 into train / test / validation subsets.
      * Run ``train`` with the validation loader and ``valid_epoch_interval=10``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="base.yaml")
    parser.add_argument('--device', default='cuda:0', help='Device')
    parser.add_argument('--n_type', type=str, default='EOG', help='noise version')
    args = parser.parse_args()
    print(args)

    path = os.path.join("config", args.config)
    with open(path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # Trailing slash kept on purpose: downstream code may append file names to
    # this string by plain concatenation.
    foldername = "./check_points/noise_type_" + args.n_type + "/"
    print('folder:', foldername)
    os.makedirs(foldername, exist_ok=True)

    X_train, y_train = prepare_data(r'data/ssed_noise.npy', r'data/ssed_eeg.npy')
    # Insert a size-1 axis at dim 1 (presumably the channel axis the model
    # expects — confirm against DualBranchDenoisingModel).
    X_train = torch.FloatTensor(X_train).unsqueeze(dim=1)
    y_train = torch.FloatTensor(y_train).unsqueeze(dim=1)
    print(X_train.shape)

    # Each sample is the pair (y_train[i], X_train[i]); the order is assumed to
    # match what `train` unpacks — verify against utils.train.
    train_val_set = TensorDataset(y_train, X_train)

    # 80% train; the remaining 20% is split evenly into test and validation.
    all_idx = list(range(len(train_val_set)))
    train_idx, val_test_idx = train_test_split(
        all_idx, test_size=0.2, random_state=SPLIT_SEED)
    # Split the held-out indices themselves. Splitting range(len(val_test_idx))
    # would yield positions 0..k-1 of the full dataset, which overlap the
    # training split and leak training data into validation/test.
    test_idx, val_idx = train_test_split(
        val_test_idx, test_size=0.5, random_state=SPLIT_SEED)

    train_set = Subset(train_val_set, train_idx)
    val_set = Subset(train_val_set, val_idx)
    test_set = Subset(train_val_set, test_idx)

    batch_size = config['train']['batch_size']
    train_loader = DataLoader(train_set, batch_size=batch_size,
                              shuffle=True, drop_last=True, num_workers=8)
    val_loader = DataLoader(val_set, batch_size=batch_size,
                            drop_last=True, num_workers=8)
    # Held-out test loader; built here but not consumed by this script.
    test_loader = DataLoader(test_set, batch_size=64, num_workers=8)

    base_model = DualBranchDenoisingModel(config['train']['feats']).to(args.device)
    model = DDPM(base_model, config, args.device)
    train(model, config['train'], train_loader, args.device,
          valid_loader=val_loader, valid_epoch_interval=10, foldername=foldername)


if __name__ == "__main__":
    main()