-- extractOther.lua (96 lines, 83 loc, 2.98 KB)
-- The GitHub page chrome and gutter line numbers captured with this file
-- have been removed so that it parses as Lua.
require 'cudnn'
require 'cunn'
require 'image'
require 'mattorch'
require 'nn'
require 'stn'
require 'trepl'
require 'xlua'
local pl = require('pl.import_into')()
-- Command-line interface. Options are registered in a fixed order (this order
-- is what `-help` prints); only -networkType must be supplied by the user.
local parser = torch.CmdLine()
for _, optSpec in ipairs({
  {'-dataName', ''},
  {'-imageNum', 6},
  {'-batchSize', 128},
  {'-modelDir', './model/'},
  {'-networkType', ''}, -- necessary
}) do
  parser:option(optSpec[1], optSpec[2])
end
local opt = parser:parse(arg)
assert(opt.networkType ~= '', ' you should specify the network type')

-- Unpack the parsed options into the locals used by the rest of the script.
local dataName, imageNum, batchSize = opt.dataName, opt.imageNum, opt.batchSize
local modelDir, networkType = opt.modelDir, opt.networkType
-- Number of values stored per patch (second dimension of the descriptor
-- tensor that the network output is copied into).
local descriptorDim = 128
-- Supported network types: the pretrained model file (looked up inside
-- modelDir) and the input patch size the model expects. The patch size also
-- selects which `local_norm_patch_<size>` field is read from the .mat file.
local networkSpecs = {
  DeepDesc_a  = {file = 'DeepDesc_all.t7',              patchSize = 64},
  DeepDesc_ly = {file = 'DeepDesc_liberty+yosemite.t7', patchSize = 64},
  PNNet       = {file = 'PNNet_liberty.t7',             patchSize = 32},
  TFeat_R     = {file = 'TFeat_RatioS_liberty.t7',      patchSize = 32},
  TFeat_M     = {file = 'TFeat_MarginS_liberty.t7',     patchSize = 32},
}
local networkSpec = networkSpecs[networkType]
if networkSpec == nil then
  print(' the model type hasnt been defined')
  -- Non-zero status so calling shell scripts can detect the failure.
  os.exit(1)
end
local model = networkSpec.file
local inputPatchSize = networkSpec.patchSize
-- Load the pretrained network and move it to the GPU. Switch it to
-- evaluation mode: the serialized module may have been saved in training
-- mode, and layers such as Dropout/BatchNormalization behave differently
-- there, which would make the extracted descriptors non-deterministic.
local network = torch.load(paths.concat(modelDir, model)):cuda()
network:evaluate()
print('use '..networkType..' to extract the feature of '..dataName)
-- For each image 1..imageNum: load its patches from
-- data/<dataName>/patch/<i>/R_64_patch.mat, run them through the network in
-- batches of batchSize, and save {frame, descriptor} to
-- data/<dataName>/patch/<i>/R_64_<networkType>.mat.
-- The loop counter is `imageIdx` (not `image`) so it does not shadow the
-- global `image` package loaded by require 'image'.
for imageIdx = 1, imageNum do
  print(' image: '..imageIdx)
  -------------------------------------------------------------------------
  --Load patch-------------------------------------------------------------
  local fileContent = mattorch.load(paths.concat('data', dataName, 'patch', tostring(imageIdx), 'R_64_patch.mat'))
  -- The original author notes that clone() is required inside this loop;
  -- it gives us tensors that own their storage independently of fileContent.
  local frame = fileContent.frame:clone()
  -- Patches for each supported size are stored as local_norm_patch_<size>.
  local storedPatch = fileContent['local_norm_patch_' .. inputPatchSize]
  if storedPatch == nil then
    print(' you should first extract patch with suitable size!')
    os.exit(1)
  end
  local patch = storedPatch:clone()
  local patchNum = patch:size(1)
  -- Due to the mattorch format, each loaded patch is transposed. Swap the
  -- two spatial axes of channel 1 for all patches at once. The clone() is
  -- needed because the source and destination views share storage.
  local firstChannel = patch:select(2, 1)
  firstChannel:copy(firstChannel:transpose(2, 3):clone())
  -- Cast to single precision; each batch is moved to the GPU (CudaTensor)
  -- before being fed to the network below.
  patch = patch:float()
  -------------------------------------------------------------------------
  --Feed into network------------------------------------------------------
  local descriptor = torch.Tensor(patchNum, descriptorDim)
  -- split() returns views, so copying into descriptorSplit[i] fills descriptor.
  local descriptorSplit = descriptor:split(batchSize)
  for i, batch in ipairs(patch:split(batchSize)) do
    descriptorSplit[i]:copy(network:forward(batch:cuda()))
  end
  -------------------------------------------------------------------------
  --Save the output--------------------------------------------------------
  local outputContent = {frame = frame, descriptor = descriptor}
  mattorch.save(paths.concat('data', dataName, 'patch', tostring(imageIdx), 'R_64_'..networkType..'.mat'), outputContent)
  -- Release this image's tensors before loading the next one.
  collectgarbage()
end