-
Notifications
You must be signed in to change notification settings - Fork 141
/
test.py
83 lines (54 loc) · 2.1 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.nn as nn
import time
from util.time import *
from util.env import *
import argparse
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import torch.nn.functional as F
from util.data import *
from util.preprocess import *
def test(model, dataloader):
    """Run ``model`` over ``dataloader`` in eval mode and collect the results.

    Each batch yielded by ``dataloader`` is unpacked as
    ``(x, y, labels, edge_index)``; all four items are moved to the device
    returned by ``get_device()`` and cast to float before the forward pass
    ``model(x, edge_index)``.

    Args:
        model: Callable module invoked as ``model(x, edge_index)``; it is
            switched to eval mode with ``model.eval()``.
        dataloader: Sized iterable of ``(x, y, labels, edge_index)`` batches
            (``len(dataloader)`` is used for progress reporting).

    Returns:
        A tuple ``(avg_loss, [predicted, ground, labels])`` where
        ``avg_loss`` is the mean of the per-batch MSE losses and the three
        list entries are the batch-concatenated predictions, targets and
        labels converted to nested Python lists. When the dataloader yields
        no batches, ``avg_loss`` is ``nan`` and the three lists are empty.
    """
    loss_func = nn.MSELoss(reduction='mean')
    device = get_device()

    start_time = time.time()
    num_batches = len(dataloader)

    batch_losses = []
    # Per-batch tensors are collected here and concatenated once after the
    # loop; concatenating inside the loop re-copies everything accumulated so
    # far on every iteration.
    predicted_chunks = []
    ground_chunks = []
    label_chunks = []

    model.eval()

    for step, (x, y, labels, edge_index) in enumerate(dataloader, start=1):
        x, y, labels, edge_index = [
            item.to(device).float() for item in (x, y, labels, edge_index)
        ]

        with torch.no_grad():
            predicted = model(x, edge_index).float().to(device)
            loss = loss_func(predicted, y)
            # Repeat each label across predicted.shape[1] columns so labels
            # line up element-wise with the predictions.
            # NOTE(review): assumes labels is 1-D (one value per sample) --
            # confirm against the dataset.
            labels = labels.unsqueeze(1).repeat(1, predicted.shape[1])

        predicted_chunks.append(predicted)
        ground_chunks.append(y)
        label_chunks.append(labels)
        batch_losses.append(loss.item())

        # Occasional progress report (fires at steps 10001, 20001, ...).
        if step % 10000 == 1 and step > 1:
            print(timeSincePlus(start_time, step / num_batches))

    if not batch_losses:
        # Nothing was evaluated: there is no mean loss and nothing to
        # concatenate, so return empty results instead of crashing.
        return float('nan'), [[], [], []]

    test_predicted_list = torch.cat(predicted_chunks, dim=0).tolist()
    test_ground_list = torch.cat(ground_chunks, dim=0).tolist()
    test_labels_list = torch.cat(label_chunks, dim=0).tolist()

    avg_loss = sum(batch_losses) / len(batch_losses)

    return avg_loss, [test_predicted_list, test_ground_list, test_labels_list]