-
Notifications
You must be signed in to change notification settings - Fork 32
/
read_patch.lua
95 lines (67 loc) · 1.87 KB
/
read_patch.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
require 'mobdebug'.start()
require 'nn'
require 'nngraph'
require 'optim'
require 'image'
local model_utils=require 'model_utils'
local mnist = require 'mnist'
-- Turn on nngraph's debug mode for this script.
nngraph.setDebug(true)

-- Problem size (kept as globals: genr_filters and the driver code read them).
N = 12         -- number of filters per axis; the read patch is N x N
A = 28         -- side length of the square input image
h_dec_n = 100  -- NOTE(review): not referenced anywhere in the visible script
n_data = 2     -- batch size; genr_filters bakes it into nn.View

-- Graph inputs: the image batch plus the raw (unconstrained) attention
-- parameters.
x = nn.Identity()()
gx_raw = nn.Identity()()
gy_raw = nn.Identity()()
sigma_raw = nn.Identity()()
delta_raw = nn.Identity()()

-- Raw stride r -> (A - 1) / (N - 1) * exp(r): spacing between filter centres.
-- (The image is square, so max(A, A) is simply A.)
delta = nn.MulConstant((A - 1) / (N - 1))(nn.Exp()(delta_raw))

-- Raw width r -> -1 / (2 * exp(r)^2): the Gaussian exponent coefficient
-- -1 / (2 * s^2) with s = exp(r), ready to multiply a squared distance.
sigma = nn.MulConstant(-1 / 2)(nn.Power(-2)(nn.Exp()(sigma_raw)))

-- Raw centre r -> pixel coordinate (A + 1) / 2 * (r + 1).
local half_span = (A + 1) / 2
gx = nn.MulConstant(half_span)(nn.AddConstant(1)(gx_raw))
gy = nn.MulConstant(half_span)(nn.AddConstant(1)(gy_raw))

-- Per-position coordinate grid, supplied at forward time and compared
-- against each filter centre inside genr_filters.
ascending = nn.Identity()()
--- Build a bank of N one-dimensional Gaussian read filters along one axis.
-- Filter i (1..N) is centred at mu_i = g + (i - N/2 - 1/2) * delta.
-- Its response at each coordinate p taken from the `ascending` node is
-- exp(sigma * (p - mu_i)^2), where `sigma` is the coefficient node built
-- by the preamble.
-- Reads N, A and n_data and the graph nodes delta, sigma and ascending from
-- the enclosing (global) scope. All temporaries are local, so repeated
-- calls no longer leak or clobber globals.
-- NOTE(review): unlike the DRAW paper, the filters are not normalised to
-- sum to 1 over the A positions; confirm this is intended.
-- NOTE(review): nn.View(n_data, 1, A) fixes the batch size at graph-build
-- time.
-- @param g nngraph node holding the filter-bank centre (gx or gy)
-- @return nngraph node joining the N filters along dim 2: (n_data, N, A)
function genr_filters(g)
   local filters = {}
   for i = 1, N do
      -- Offset of filter i from the centre g, scaled by the stride delta.
      local offset = nn.MulConstant(i - N / 2 - 1 / 2)(delta)
      local mu_i = nn.MulConstant(-1)(nn.CAddTable()({g, offset}))
      -- (p - mu_i)^2 for every coordinate p.
      local d_i = nn.Power(2)(nn.CAddTable()({mu_i, ascending}))
      local exp_i = nn.Exp()(nn.CMulTable()({d_i, sigma}))
      -- Add a singleton filter axis so the N filters can be joined on dim 2.
      filters[#filters + 1] = nn.View(n_data, 1, A)(exp_i)
   end
   return nn.JoinTable(2)(filters)
end
-- Assemble the read graph. For each batch element the patch is
-- F_x * image * F_y^T, with F_x and F_y the (N x A) filter banks, giving
-- an N x N patch.
-- NOTE(review): F_x multiplies from the left, so gx positions the filters
-- along the first (row) axis of x. The DRAW paper applies F_X to columns
-- (read = F_Y x F_X^T); confirm the intended axis convention.
filterbank_x = genr_filters(gx)
filterbank_y = genr_filters(gy)
local rows_read = nn.MM()({filterbank_x, x})           -- (n_data, N, A)
patch = nn.MM(false, true)({rows_read, filterbank_y})  -- (n_data, N, N)
m = nn.gModule({x, gx_raw, gy_raw, delta_raw, sigma_raw, ascending}, {patch})
-- Load MNIST, run the read graph on the first n_data training digits and
-- save the resulting patches.
trainset = mnist.traindataset()
testset = mnist.testdataset()  -- NOTE(review): loaded but never used below

-- Image batch (n_data x A x A): 1 where the MNIST pixel exceeds 125, else 0.
x = torch.zeros(n_data, A, A)
for idx = 1, n_data do
   x[idx]:copy(trainset[idx].x:gt(125))
end

-- Coordinate grid: every row holds 1, 2, ..., A.
ascending = torch.zeros(n_data, A)
for col = 1, A do
   ascending:select(2, col):fill(col)
end

-- Raw attention parameters, all zero. These tensors rebind the globals that
-- previously held graph nodes.
local function zero_params()
   return torch.zeros(n_data, A)
end
gx, gy = zero_params(), zero_params()
sigma, delta = zero_params(), zero_params()

-- Input order must match the gModule inputs:
-- {x, gx_raw, gy_raw, delta_raw, sigma_raw, ascending}.
z = m:forward({x, gx, gy, delta, sigma, ascending})
print(x:gt(0.5))
print(z:gt(0.5))
torch.save('read_patches', z)