-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_runner.py
82 lines (69 loc) · 2.08 KB
/
train_runner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import matplotlib.pyplot as plt
from src.lstm.data import INSTANCE_KEYS, LSTMDataset
from src.lstm.models import LSTMNetwork
from src.lstm.trainer import LSTMTrainer
def _check_stateful_config(*, stateful, shuffle, batch_size, window_size):
    """Fail fast if a stateful run is configured inconsistently.

    The configuration note in ``run`` requires that a stateful model be used
    with ``shuffle=False`` and ``batch_size == window_size``. Checking it here
    turns a silent misconfiguration into an immediate error, raised before the
    datasets and model are built. (Presumably the requirement exists because
    hidden state is carried from one batch to the next — TODO confirm in
    ``LSTMNetwork``.)

    Args:
        stateful: Value passed as ``stateful`` to ``LSTMNetwork``.
        shuffle: Value passed as ``shuffle`` to ``LSTMTrainer.train``.
        batch_size: Value passed as ``batch_size`` to ``LSTMTrainer.train``.
        window_size: Window length passed to ``LSTMDataset``.

    Raises:
        ValueError: If ``stateful`` is true and either constraint is violated.
    """
    if not stateful:
        return
    if shuffle:
        raise ValueError("stateful training requires shuffle=False")
    if batch_size != window_size:
        raise ValueError(
            "stateful training requires batch_size == window_size "
            f"(got batch_size={batch_size}, window_size={window_size})"
        )


def _plot_history(history, label, path):
    """Plot one loss history against its index and save it as an image.

    The figure is closed after saving (even if saving fails), so repeated
    calls do not keep large open figures alive.

    Args:
        history: Sequence of loss values, plotted against ``range(len(history))``.
        label: Legend label for the curve.
        path: Output file path passed to ``Figure.savefig``.
    """
    fig, ax = plt.subplots(figsize=(32, 18), dpi=300)
    try:
        ax.plot(range(len(history)), history, label=label)
        ax.legend()
        fig.savefig(path)
    finally:
        plt.close(fig)


def run():
    """Train the stateful LSTM on the CSV datasets and save loss plots.

    Uses the hard-coded configuration below to build training and validation
    ``LSTMDataset`` objects, an ``LSTMNetwork`` and an ``LSTMTrainer``, and
    then runs training. It saves two PNG plots to the current working
    directory: ``loss_stateful_5.png`` (training loss history) and
    ``valid_loss_stateful_5.png`` (validation loss history).

    Raises:
        ValueError: If ``stateful`` is enabled while ``shuffle`` is true or
            ``batch_size`` differs from ``window_size``.
    """
    # --- data ---
    train_data_src = "./data/training_data.csv"
    valid_data_src = "./data/validation_data.csv"
    window_size = 48  # 4 hours -> 48 five-minute intervals

    # --- model architecture ---
    input_dim = len(INSTANCE_KEYS)  # input width = number of INSTANCE_KEYS
    hidden_dim = 64
    fc_dim = 16
    output_dim = 1
    attn_layer = False
    # Stateful mode requires shuffle=False and batch_size == window_size;
    # this is enforced by _check_stateful_config below.
    stateful = True
    # NOTE(review): presumably PyTorch's Apple "mps" backend. There is no
    # CPU/CUDA fallback, so this likely fails on machines without MPS — confirm.
    device = "mps"

    # --- training parameters ---
    batch_size = 48
    drop_last = True
    learn_rate = 1e-6
    epochs = 25
    optimizer = "adamw"
    loss_fn = "mse"
    shuffle = False
    quiet = False
    model_file_name = "lstm_stateful.pth"

    _check_stateful_config(
        stateful=stateful,
        shuffle=shuffle,
        batch_size=batch_size,
        window_size=window_size,
    )

    # --- datasets ---
    train_dataset = LSTMDataset(data_src=train_data_src, window_size=window_size)
    valid_dataset = LSTMDataset(data_src=valid_data_src, window_size=window_size)

    # --- model ---
    lstm = LSTMNetwork(
        input_dim=input_dim,
        hidden_dim=hidden_dim,
        fc_dim=fc_dim,
        output_dim=output_dim,
        attn_layer=attn_layer,
        stateful=stateful,
        device=device,
    )

    # --- trainer ---
    trainer = LSTMTrainer(
        model=lstm,
        optimizer=optimizer,
        loss_fn=loss_fn,
        learn_rate=learn_rate,
        file_name=model_file_name,
    )

    # --- training ---
    train_history, valid_history = trainer.train(
        train_dataset=train_dataset,
        valid_dataset=valid_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        epochs=epochs,
        drop_last=drop_last,
        quiet=quiet,
    )

    # --- plot outcome: one separate figure per loss curve ---
    _plot_history(train_history, "Train Loss", "loss_stateful_5.png")
    _plot_history(valid_history, "Valid Loss", "valid_loss_stateful_5.png")
# Script entry point: train and plot only when this file is executed directly,
# never as a side effect of importing it.
if __name__ == "__main__":
    run()