-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathtest.py
44 lines (32 loc) · 1.08 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import argparse
import os

import torch
from torch.utils.data import DataLoader

from dataset import Flickr30dataset
from model import MATnet
from train_model import evaluate
from utils.utils import load_vocabulary, init_net
def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]``, exactly as the
            script did before; passing a list lets tests or other scripts
            call this without touching the real command line.

    Returns:
        argparse.Namespace with a ``file`` attribute holding the path of the
        saved model checkpoint to evaluate.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--file',
        type=str,
        default="saved/model_0527_a20.pt",
        help="saved model name",
    )
    return parser.parse_args(argv)
def main():
    """Evaluate a saved MATnet checkpoint on the Flickr30 "test" split.

    Loads the GloVe vocabulary, builds the test DataLoader, restores the
    model weights from ``--file`` via ``init_net`` and prints the score
    returned by ``evaluate``.
    """
    args = parse_args()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    word_embedding = load_vocabulary("data/glove/glove.6B.300d.txt")
    test_dset = Flickr30dataset(word_embedding, "test")
    # shuffle=False: evaluation does not need a random order, and combined
    # with drop_last=True the old shuffle=True dropped a *different* random
    # subset of samples on every run, making the score non-reproducible.
    # NOTE(review): drop_last=True still skips the final partial batch
    # (len(test_dset) % 32 samples). Kept in case evaluate() assumes full
    # batches -- confirm, then consider drop_last=False to score everything.
    test_loader = DataLoader(
        test_dset,
        batch_size=32,
        num_workers=4,
        drop_last=True,
        shuffle=False,
    )

    net = MATnet(word_embedding)
    if device.type == "cuda":
        print("CUDA available")
    # Place the model on the target device before restoring weights (the
    # original did this with net.cuda() when CUDA was available).
    net = net.to(device)
    init_net(net, args.file)
    # NOTE(review): init_net's internals are not visible here; the move is
    # re-applied in case it replaces parameters instead of loading in place.
    net = net.to(device)
    net.eval()

    # NOTE(review): evaluate() is not run under torch.no_grad(). If it does
    # not need gradients, wrapping it would save memory -- confirm first.
    score, _ = evaluate(test_loader, net)
    print("evaluation localization score:", score)


if __name__ == '__main__':
    main()