utils.lua
-- convolutions with 'same' option
function fex.conv2(...)
   local arg = {...}
   if arg[#arg] == 'S' then
      local ro,x,k = nil,nil,nil
      if #arg == 4 then
         -- called as fex.conv2(ro,x,k,'S'): write the result into ro
         ro = arg[1]
         x = arg[2]
         k = arg[3]
      else
         -- called as fex.conv2(x,k,'S'): allocate the result
         x = arg[1]
         k = arg[2]
         ro = x.new(x:size())
      end
      -- full convolution, then crop the central part so the output
      -- has the same size as the input
      local r = torch.conv2(x,k,'F')
      local shifti = 1+math.ceil((r:size(1)-x:size(1))/2)
      local shiftj = 1+math.ceil((r:size(2)-x:size(2))/2)
      local ii = r:dim()-1
      local jj = r:dim()
      ro:resizeAs(x)
      ro:copy(r:narrow(ii,shifti,x:size(1)):narrow(jj,shiftj,x:size(2)))
      return ro
   else
      return torch.conv2(...)
   end
end
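-- Usage sketch (hypothetical values, assuming torch and the fex table are loaded):
--   local x = torch.rand(5,5)
--   local k = torch.ones(3,3):div(9)   -- 3x3 box filter
--   local y  = fex.conv2(x,k,'S')      -- 'same' output, same size as x
--   local yv = fex.conv2(x,k)          -- falls through to torch.conv2 ('V', valid size)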
-- cross correlations with 'same' option
function fex.xcorr2(...)
   local arg = {...}
   if arg[#arg] == 'S' then
      local ro,x,k = nil,nil,nil
      if #arg == 4 then
         -- called as fex.xcorr2(ro,x,k,'S'): write the result into ro
         ro = arg[1]
         x = arg[2]
         k = arg[3]
      else
         -- called as fex.xcorr2(x,k,'S'): allocate the result
         x = arg[1]
         k = arg[2]
         ro = x.new(x:size())
      end
      -- full cross-correlation, then crop the central part so the output
      -- has the same size as the input
      local r = torch.xcorr2(x,k,'F')
      local shifti = 1+math.ceil((r:size(1)-x:size(1))/2)
      local shiftj = 1+math.ceil((r:size(2)-x:size(2))/2)
      local ii = r:dim()-1
      local jj = r:dim()
      ro:resizeAs(x)
      ro:copy(r:narrow(ii,shifti,x:size(1)):narrow(jj,shiftj,x:size(2)))
      return ro
   else
      return torch.xcorr2(...)
   end
end
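-- Usage sketch: same calling convention as fex.conv2 above, but the kernel is not
-- flipped (cross-correlation instead of convolution):
--   local z = fex.xcorr2(torch.rand(5,5), torch.rand(3,3), 'S')   -- same size as the input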
-- numerical gradient of a tensor.
-- returns one tensor per dimension: the first return value is the gradient
-- along the last dimension, the last along the first dimension.
function fex.gradient(x)
   -- if not dim then dim = torch.range(0,x:dim()):narrow(1,2,x:dim()) end
   -- if type(dim) == 'number' then dim = torch.Tensor({dim}) end
   local ndim = x:dim()
   local function grad(x,dim)
      local sz = x:size()
      if sz[dim] == 1 then return x:clone():zero() end
      sz[dim] = sz[dim]+2
      local xx = x.new(sz):zero()
      -- copy center
      xx:narrow(dim,2,x:size(dim)):copy(x)
      -- extrapolate the beginning linearly: ff = 2*f1 - f2
      local ff = xx:narrow(dim,1,1)
      local f1 = xx:narrow(dim,2,1)
      local f2 = xx:narrow(dim,3,1)
      torch.add(ff,f1,-1,f2)
      ff:add(f1)
      -- extrapolate the ending linearly: fe = 2*ff1 - ff2
      local xend = xx:size(dim)
      local fe = xx:narrow(dim,xend,1)
      local ff1 = xx:narrow(dim,xend-1,1)
      local ff2 = xx:narrow(dim,xend-2,1)
      torch.add(fe,ff1,-1,ff2)
      fe:add(ff1)
      -- central difference: d[i] = (xx[i+2] - xx[i])/2
      local d = xx:narrow(dim,3,xend-2):clone()
      d:add(-1,xx:narrow(dim,1,xend-2))
      return d:div(2)
   end
   local res = {}
   for i=1,ndim do
      table.insert(res,i,grad(x,ndim-i+1))
   end
   return unpack(res)
end
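-- Usage sketch: central differences inside, linear extrapolation at the borders;
-- one tensor is returned per dimension:
--   local img = torch.rand(4,6)
--   local gx, gy = fex.gradient(img)   -- gx: along dim 2 (columns), gy: along dim 1 (rows)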
-- narrow x down to the original (unpadded) region along every dimension whose
-- index is greater than dim
local function dimnarrow(x,sz,pad,dim)
   local xn = x
   for i=1,x:dim() do
      if i > dim then
         xn = xn:narrow(i,pad[i]+1,sz[i])
      end
   end
   return xn
end
-- pad x with pad[i] zeros on both sides of dimension i
local function padzero(x,pad)
   local sz = x:size()
   for i=1,x:dim() do sz[i] = sz[i]+pad[i]*2 end
   local xx = x.new(sz):zero()
   local xn = dimnarrow(xx,x:size(),pad,-1)
   xn:copy(x)
   return xx
end
-- pad x with pad[i] entries on both sides of dimension i by mirroring the border
local function padmirror(x,pad)
   local xx = padzero(x,pad)
   local sz = xx:size()
   for i=1,x:dim() do
      local xxn = dimnarrow(xx,x:size(),pad,i)
      for j=1,pad[i] do
         xxn:select(i,j):copy(xxn:select(i,pad[i]*2-j+1))
         xxn:select(i,sz[i]-j+1):copy(xxn:select(i,sz[i]-pad[i]*2+j))
      end
   end
   return xx
end
-- pad a tensor: pad is a table with one pad size per dimension,
-- padtype is either 'zero' or 'mirror'
function fex.padarray(x,pad,padtype)
   if x:dim() ~= #pad then
      error('number of dimensions of input should match number of padding sizes')
   end
   if padtype == 'zero' then return padzero(x,pad) end
   if padtype == 'mirror' then return padmirror(x,pad) end
   error('unknown padtype ' .. padtype)
end
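-- Usage sketch:
--   local t  = torch.rand(2,3)
--   local tz = fex.padarray(t,{1,2},'zero')     -- 4x7, zero padding
--   local tm = fex.padarray(t,{1,2},'mirror')   -- 4x7, mirrored borders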
-- replicate a tensor (MATLAB-like repmat): dims is a table giving the number of
-- copies along each dimension
function fex.repmat(x,dims)
   if x:dim() ~= #dims then
      error('number of replication dims should be equal to number of dimensions of tensor')
   end
   -- output size: each dimension of x replicated dims[i] times
   local sz = torch.LongStorage(#dims)
   for i=1,sz:size() do
      sz[i] = x:size(i)*dims[i]
   end
   local rm = x.new():resize(sz)
   -- fill the result by tiling x along every dimension
   rm:copy(torch.repeatTensor(x,unpack(dims)))
   return rm
end
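-- Usage sketch:
--   local a = torch.eye(2)
--   local b = fex.repmat(a,{2,3})   -- 4x6: 'a' tiled twice along rows, three times along cols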