-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtest.py
120 lines (83 loc) · 4.28 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import argparse
import os
from tqdm import tqdm
import numpy as np
import torch
from torch.utils.data import DataLoader
from model.trajairnet import TrajAirNet
from model.utils import ade, fde, TrajectoryDataset, seq_collate
def _build_arg_parser():
    """Return the command-line parser used to evaluate a saved TrajAirNet model."""
    parser = argparse.ArgumentParser(description='Test TrajAirNet model')
    parser.add_argument('--dataset_folder', type=str, default='/dataset/')
    parser.add_argument('--dataset_name', type=str, default='7days1')
    parser.add_argument('--epoch', type=int, required=True)
    parser.add_argument('--obs', type=int, default=11)
    parser.add_argument('--preds', type=int, default=120)
    parser.add_argument('--preds_step', type=int, default=10)
    ##Network params (the whole namespace is passed to TrajAirNet(args))
    parser.add_argument('--input_channels', type=int, default=3)
    parser.add_argument('--tcn_channel_size', type=int, default=256)
    parser.add_argument('--tcn_layers', type=int, default=2)
    parser.add_argument('--tcn_kernels', type=int, default=4)
    parser.add_argument('--num_context_input_c', type=int, default=2)
    parser.add_argument('--num_context_output_c', type=int, default=7)
    parser.add_argument('--cnn_kernels', type=int, default=2)
    parser.add_argument('--gat_heads', type=int, default=16)
    parser.add_argument('--graph_hidden', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.05)
    parser.add_argument('--alpha', type=float, default=0.2)
    parser.add_argument('--cvae_hidden', type=int, default=128)
    parser.add_argument('--cvae_channel_size', type=int, default=128)
    parser.add_argument('--cvae_layers', type=int, default=2)
    parser.add_argument('--mlp_layer', type=int, default=32)
    parser.add_argument('--delim', type=str, default=' ')
    parser.add_argument('--model_dir', type=str, default="/saved_models/")
    # Optional: makes the randomly sampled latents (and hence the reported
    # best-of-N metrics) reproducible across runs. Default keeps old behaviour.
    parser.add_argument('--seed', type=int, default=None,
                        help='optional RNG seed for reproducible evaluation')
    return parser


def main():
    """Load a saved TrajAirNet checkpoint and print its test ADE/FDE.

    Reads the test split from
    ``<cwd><dataset_folder><dataset_name>/processed_data/test`` and the
    checkpoint from ``<cwd><model_dir>model_<dataset_name>_<epoch>.pt``.
    Both paths are built by plain string concatenation, so
    ``--dataset_folder`` and ``--model_dir`` are expected to start and end
    with ``/``.
    """
    args = _build_arg_parser().parse_args()

    if args.seed is not None:
        torch.manual_seed(args.seed)

    ##Select device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ##Load data
    datapath = os.getcwd() + args.dataset_folder + args.dataset_name + "/processed_data/"
    print("Loading Test Data from ", datapath + "test")
    dataset_test = TrajectoryDataset(datapath + "test", obs_len=args.obs, pred_len=args.preds,
                                     step=args.preds_step, delim=args.delim)
    # The metrics are averaged over all batches, so order is irrelevant;
    # shuffling would only make runs harder to reproduce.
    loader_test = DataLoader(dataset_test, batch_size=1, num_workers=4, shuffle=False,
                             collate_fn=seq_collate)

    ##Load model
    model = TrajAirNet(args)
    model.to(device)
    model_path = os.getcwd() + args.model_dir + "model_" + args.dataset_name + "_" + str(args.epoch) + ".pt"
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Evaluation mode: disables train-only behaviour such as dropout
    # (the network is configured with --dropout).
    model.eval()

    test_ade_loss, test_fde_loss = test(model, loader_test, device)
    print("Test ADE Loss: ", test_ade_loss, "Test FDE Loss: ", test_fde_loss)
def _evaluate_batch(model, batch, device, num_samples, latent_dim):
    """Return the best-of-``num_samples`` ``(ade, fde)`` for one batch.

    For each latent draw, ADE and FDE are averaged over the agents in the
    batch. The draw with the lowest mean ADE wins, and the FDE returned is
    that same draw's FDE. FDE is not minimised independently.
    """
    batch = [tensor.to(device) for tensor in batch]
    obs_traj_all, pred_traj_all, _obs_rel, _pred_rel, context, _seq_start = batch
    num_agents = obs_traj_all.shape[1]

    # These inputs do not depend on the latent draw, so build them once.
    obs_in = torch.transpose(obs_traj_all, 1, 2)
    context_in = torch.transpose(context, 1, 2)
    # All-ones adjacency, created on `device` alongside the other inputs.
    adj = torch.ones((num_agents, num_agents), device=device)
    # Per-agent ground truth (agent index is dim 1 of pred_traj_all).
    gt_trajs = [np.squeeze(pred_traj_all[:, agent, :].cpu().numpy())
                for agent in range(num_agents)]

    best_ade = float('inf')
    best_fde = float('inf')
    for _ in range(num_samples):
        z = torch.randn([1, 1, latent_dim], device=device)
        recon_y_all = model.inference(obs_in, z, adj, context_in)
        ade_sum = 0
        fde_sum = 0
        for agent, gt in enumerate(gt_trajs):
            recon_pred = np.squeeze(recon_y_all[agent].detach().cpu().numpy()).transpose()
            ade_sum += ade(recon_pred, gt)
            fde_sum += fde(recon_pred, gt)
        mean_ade = ade_sum / num_agents
        mean_fde = fde_sum / num_agents
        if mean_ade < best_ade:
            best_ade = mean_ade
            best_fde = mean_fde
    return best_ade, best_fde


def test(model, loader_test, device, num_samples=5, latent_dim=128):
    """Evaluate ``model`` on ``loader_test`` using best-of-N ADE/FDE.

    For every batch, ``num_samples`` latent vectors of shape
    ``[1, 1, latent_dim]`` are drawn from a standard normal and fed to
    ``model.inference``. The sample with the lowest agent-averaged ADE is
    kept together with its paired FDE. The kept values are then averaged
    over all batches.

    The model is switched to eval mode and run under ``torch.no_grad()``.
    Its previous train/eval mode is restored before returning.

    Args:
        model: ``torch.nn.Module`` exposing ``inference(obs, z, adj, context)``.
        loader_test: Iterable of batches that unpack into six tensors
            ``(obs_traj, pred_traj, obs_traj_rel, pred_traj_rel, context,
            seq_start)``.
        device: Device the batch tensors, ``z`` and ``adj`` are placed on.
        num_samples: Latent draws per batch (default 5).
        latent_dim: Size of the last dimension of ``z`` (default 128).

    Returns:
        Tuple ``(mean_ade, mean_fde)`` averaged over all batches.

    Raises:
        ValueError: If ``num_samples < 1`` or ``loader_test`` yields no batches.
    """
    if num_samples < 1:
        raise ValueError("num_samples must be >= 1")

    tot_ade_loss = 0
    tot_fde_loss = 0
    tot_batch = 0
    was_training = model.training
    model.eval()
    try:
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            for batch in tqdm(loader_test):
                best_ade, best_fde = _evaluate_batch(model, batch, device, num_samples, latent_dim)
                tot_ade_loss += best_ade
                tot_fde_loss += best_fde
                tot_batch += 1
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

    if tot_batch == 0:
        raise ValueError("loader_test yielded no batches")
    return tot_ade_loss / tot_batch, tot_fde_loss / tot_batch
if __name__ == '__main__':
    # Entry point: evaluate only when run as a script, never on import.
    main()