-
Notifications
You must be signed in to change notification settings - Fork 32
/
Copy pathmain_test.lua
44 lines (35 loc) · 1.22 KB
/
main_test.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
-- Test entry point: builds the dataset and the detection network described by
-- the global `config` table, then runs the test pass via NetworkWrapper.
-- NOTE(review): `config`, `detection` and `paths` are not defined in this
-- script; presumably `require 'detection'` makes them available -- confirm.
require 'detection'

-- Paths
local dataset_name = config.dataset
local image_set = config.test_img_set
local dataset_dir = paths.concat(config.dataset_path, dataset_name)
local ss_dir = './data/datasets/selective_search_data/'
-- The .mat file is named "<dataset>_<image set>.mat" and lives under ss_dir.
local ss_file = paths.concat(ss_dir, dataset_name .. '_' .. image_set .. '.mat')
local param_path = config.model_weights
local model_path = config.model_def

-- Loading the dataset
local dataset
local model_opt = {}
if dataset_name == 'MSCOCO' then
  print('MSCOCO ' .. image_set)
  dataset = detection.DataSetCoco({
    image_set = image_set,
    datadir = dataset_dir,
    test_mode = false,
  })
  model_opt.nclass = 80
else
  -- Any dataset name other than 'MSCOCO' is handled as Pascal VOC.
  print('VOC ' .. image_set)
  -- Year defaults to 2007 and switches to 2012 only when the dataset name
  -- contains the literal text "2012". Plain-text search (4th argument `true`)
  -- with an explicit string avoids relying on number-to-string coercion.
  local year = 2007
  if dataset_name:find('2012', 1, true) then
    year = 2012
  end
  dataset = detection.DataSetPascal({
    image_set = image_set,
    datadir = dataset_dir,
    roidbdir = ss_dir,
    roidbfile = ss_file,
    year = year,
  })
  model_opt.nclass = 20
end

-- Creating the detection net
model_opt.test = true
-- Key spelling ("fine_tunning") is kept exactly as-is: it is consumed outside
-- this script, so renaming it here would change what the callee sees.
model_opt.fine_tunning = false
-- NOTE(review): `network` is assigned as a global, and NetworkWrapper() below
-- is constructed without arguments; presumably the wrapper reads this global.
-- Confirm before turning it into a local.
network = detection.Net(model_path, param_path, model_opt)

-- Creating the wrapper
local network_wrapper = detection.NetworkWrapper()

-- Test the network
print('Testing the network...')
network_wrapper:testNetwork(dataset)