-
Notifications
You must be signed in to change notification settings - Fork 9
/
model.lua
40 lines (29 loc) · 818 Bytes
/
model.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
require 'nn'
require 'cunn'
require 'cudnn'
require 'dpnn'
require 'optim'
paths.dofile('TripletEmbedding.lua')
-- Obtain the anchor (embedding) network, stored in the global modelAnchor.
-- Two sources:
--   * opt.retrain names a serialized model file -> load it with torch.load;
--   * opt.retrain is 'none' (or unset)          -> run opt.modelDef, which
--     must define a global createModel(nGPU), and build a fresh model.
-- A nil opt.retrain is treated like 'none' instead of crashing inside
-- paths.filep(nil) with an unhelpful argument error.
local retrainPath = opt.retrain
if retrainPath ~= nil and retrainPath ~= 'none' then
   assert(paths.filep(retrainPath), 'File not found: ' .. retrainPath)
   print('Loading model from file: ' .. retrainPath)
   modelAnchor = torch.load(retrainPath)
else
   paths.dofile(opt.modelDef)
   -- Fail with a clear message if the definition file forgot createModel.
   assert(type(createModel) == 'function',
          'Model definition does not define createModel(): ' .. tostring(opt.modelDef))
   modelAnchor = createModel(opt.nGPU)
end
-- Assemble the triplet network: the anchor, positive and negative branches
-- are three copies of modelAnchor wrapped in an nn.ParallelTable, so the
-- i-th branch processes the i-th element of the input table.
--
-- Both clones share these fields with modelAnchor. Because gradWeight and
-- gradBias are shared too, gradients from all three branches accumulate into
-- the same buffers, and one parameter update trains every branch.
-- NOTE(review): fields such as BatchNorm running_mean/running_var are not
-- in this list, so each branch keeps its own running statistics. Confirm
-- this is intended if the model definition uses batch normalization.
local sharedFields = {'weight', 'bias', 'gradWeight', 'gradBias'}
local unpackFields = table.unpack or unpack
modelPos = modelAnchor:clone(unpackFields(sharedFields))
modelNeg = modelAnchor:clone(unpackFields(sharedFields))

model = nn.ParallelTable()
model:add(modelAnchor)
model:add(modelPos)
model:add(modelNeg)

-- Margin for the triplet embedding criterion. It can be overridden through
-- opt.alpha and defaults to the previous hard-coded value of 0.2.
alpha = opt.alpha or 0.2
criterion = nn.TripletEmbeddingCriterion(alpha)

-- Move the whole network and the criterion to the GPU.
model = model:cuda()
criterion:cuda()

print('=> Model')
print(model)
print('=> Criterion')
print(criterion)
collectgarbage()