supervised.py
#!/usr/bin/env python
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from tetris import TetrisNN


def main():
    epochs = 50000

    # Build the network on the MPS device and resume from previously saved weights.
    model = TetrisNN(44).to('mps')
    model.load_state_dict(torch.load("supervised.pt"))
    model.train()  # make sure the network is in training mode before optimizing

    optimizer = optim.AdamW(model.parameters(), lr=1e-5, amsgrad=True)
    criterion = nn.SmoothL1Loss()

    # Each dataset row holds a 20-value board state followed by the target values;
    # split it into inputs X of shape (N, 1, 20) and targets y on the same device.
    dataset = np.load('dataset.npy')
    X = torch.tensor(dataset[:, :20].reshape(len(dataset), 1, 20), device='mps', dtype=torch.float32)
    y = torch.tensor(dataset[:, 20:], device='mps', dtype=torch.float32)

    for epoch in range(epochs):
        print('ITERATION = %d/%d' % (epoch + 1, epochs))

        predictions = model(X)
        loss = criterion(predictions, y)

        # Optimize the model
        optimizer.zero_grad()
        loss.backward()
        # In-place gradient clipping
        torch.nn.utils.clip_grad_value_(model.parameters(), 100)
        optimizer.step()

        # Checkpoint the updated weights after every epoch.
        torch.save(model.state_dict(), "supervised.pt")
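

# A minimal usage sketch (not part of the original script): once training has
# finished, the saved weights can be loaded back into TetrisNN for inference.
# The helper below is hypothetical; it assumes a single 20-value board state
# shaped the same way as one row of X above.
def predict(board_state):
    model = TetrisNN(44).to('mps')
    model.load_state_dict(torch.load("supervised.pt"))
    model.eval()
    with torch.no_grad():
        x = torch.tensor(board_state, device='mps', dtype=torch.float32).reshape(1, 1, 20)
        return model(x)
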

if __name__ == '__main__':
    main()