-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c71518a
commit deaa6dc
Showing
7 changed files
with
380 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
function [Y, Y_w] = vl_myfc(X, W, dzdy) | ||
%regular fully connected layer | ||
|
||
[n1,n2,n3,n4,n5] = size(X); | ||
|
||
X_t = reshape(X, n1*n2*n3*n4,n5); | ||
|
||
if nargin < 3 | ||
Y = W * X_t; | ||
else | ||
Y = W' * dzdy; | ||
Y_w = dzdy*X_t'; | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
function res = vl_myforbackward(net, x, dzdy, res, varargin) | ||
% VL_SIMPLENN Evaluates a simple LieNet | ||
|
||
opts.res = [] ; | ||
opts.conserveMemory = false ; | ||
opts.sync = false ; | ||
opts.disableDropout = false ; | ||
opts.freezeDropout = false ; | ||
opts.accumulate = false ; | ||
opts.cudnn = true ; | ||
opts.skipForward = false; | ||
opts.backPropDepth = +inf ; | ||
opts.epsilon = 5e-1; | ||
|
||
% opts = vl_argparse(opts, varargin); | ||
|
||
n = numel(net.layers) ; | ||
|
||
if (nargin <= 2) || isempty(dzdy) | ||
doder = false ; | ||
else | ||
doder = true ; | ||
end | ||
|
||
if opts.cudnn | ||
cudnn = {'CuDNN'} ; | ||
else | ||
cudnn = {'NoCuDNN'} ; | ||
end | ||
|
||
gpuMode = isa(x, 'gpuArray') ; | ||
|
||
if nargin <= 3 || isempty(res) | ||
res = struct(... | ||
'x', cell(1,n+1), ... | ||
'dzdx', cell(1,n+1), ... | ||
'dzdw', cell(1,n+1), ... | ||
'aux', cell(1,n+1), ... | ||
'time', num2cell(zeros(1,n+1)), ... | ||
'backwardTime', num2cell(zeros(1,n+1))) ; | ||
end | ||
if ~opts.skipForward | ||
res(1).x = x ; | ||
end | ||
|
||
|
||
% ------------------------------------------------------------------------- | ||
% Forward pass | ||
% ------------------------------------------------------------------------- | ||
for i=1:n | ||
if opts.skipForward, break; end; | ||
l = net.layers{i} ; | ||
res(i).time = tic ; | ||
switch l.type | ||
case 'rotmap' | ||
res(i+1).x = vl_myrotmap(res(i).x, l.weight) ; | ||
|
||
case 'logmap' | ||
[res(i+1).x, res(i)] = vl_mylogmap(res(i)) ; | ||
|
||
case 'relu' | ||
[res(i+1).x, res(i)] = vl_myrelu(res(i)) ; | ||
|
||
case 'pooling' | ||
[res(i+1).x, res(i)] = vl_mypooling(res(i), l.pool) ; | ||
|
||
case 'fc' | ||
res(i+1).x = vl_myfc(res(i).x, l.weight) ; | ||
case 'softmaxloss' | ||
res(i+1).x = vl_mysoftmaxloss(res(i).x, l.class) ; | ||
case 'custom' | ||
res(i+1) = l.forward(l, res(i), res(i+1)) ; | ||
otherwise | ||
error('Unknown layer type %s', l.type) ; | ||
end | ||
% optionally forget intermediate results | ||
forget = opts.conserveMemory ; | ||
forget = forget & (~doder || strcmp(l.type, 'relu')) ; | ||
forget = forget & ~(strcmp(l.type, 'loss') || strcmp(l.type, 'softmaxloss')) ; | ||
forget = forget & (~isfield(l, 'rememberOutput') || ~l.rememberOutput) ; | ||
if forget | ||
res(i).x = [] ; | ||
end | ||
if gpuMode & opts.sync | ||
% This should make things slower, but on MATLAB 2014a it is necessary | ||
% for any decent performance. | ||
wait(gpuDevice) ; | ||
end | ||
res(i).time = toc(res(i).time) ; | ||
end | ||
|
||
% ------------------------------------------------------------------------- | ||
% Backward pass | ||
% ------------------------------------------------------------------------- | ||
|
||
if doder | ||
res(n+1).dzdx = dzdy ; | ||
for i=n:-1:max(1, n-opts.backPropDepth+1) | ||
l = net.layers{i} ; | ||
res(i).backwardTime = tic ; | ||
switch l.type | ||
case 'rotmap' | ||
[res(i).dzdx, res(i).dzdw] = ... | ||
vl_myrotmap(res(i).x, l.weight, res(i+1).dzdx) ; | ||
|
||
case 'logmap' | ||
res(i).dzdx = vl_mylogmap(res(i), res(i+1).dzdx) ; | ||
|
||
case 'relu' | ||
res(i).dzdx = vl_myrelu(res(i), res(i+1).dzdx) ; | ||
|
||
case 'pooling' | ||
res(i).dzdx = vl_mypooling(res(i), l.pool, res(i+1).dzdx) ; | ||
|
||
case 'fc' | ||
[res(i).dzdx, res(i).dzdw] = ... | ||
vl_myfc(res(i).x, l.weight, res(i+1).dzdx) ; | ||
case 'softmaxloss' | ||
res(i).dzdx = vl_mysoftmaxloss(res(i).x, l.class, res(i+1).dzdx) ; | ||
case 'custom' | ||
res(i) = l.backward(l, res(i), res(i+1)) ; | ||
end | ||
if opts.conserveMemory | ||
res(i+1).dzdx = [] ; | ||
end | ||
if gpuMode & opts.sync | ||
wait(gpuDevice) ; | ||
end | ||
res(i).backwardTime = toc(res(i).backwardTime) ; | ||
end | ||
end | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
function [Y, R] = vl_mylogmap(R, dzdy) | ||
%logarithm mapping (LogMap) layer | ||
|
||
X = R.x; | ||
A = R.aux; | ||
[n1,n2,n3,n4,n5] = size(X); | ||
Y = zeros(n1,n2,n3,n4,n5); | ||
|
||
|
||
if isempty(A) == 1 | ||
parfor i3 = 1 : n3 | ||
for i4 = 1 : n4 | ||
for i5 = 1 : n5 | ||
X_t = X(:,:,i3,i4,i5); | ||
Y_t = zeros(n1,n2); | ||
axis_angle = vrrotmat2vec_modified(X_t); | ||
Y_t([1 2 3 6]) = axis_angle; | ||
Y(:,:,i3,i4,i5) = Y_t; | ||
end | ||
end | ||
end | ||
R.aux = 1; | ||
else | ||
dzdy = reshape(dzdy,n1,n2,n3,n4,n5); | ||
parfor i3 = 1 : n3 | ||
for i4 = 1 : n4 | ||
for i5 = 1 : n5 | ||
X_t = X(:,:,i3,i4,i5); D_t = dzdy(:,:,i3,i4,i5); | ||
Y(:,:,i3,i4,i5) = calculate_grad_log_angel(X_t,D_t,n1); | ||
end | ||
end | ||
end | ||
end | ||
|
||
function dzdx = calculate_grad_log_angel(X_t,D_t,n1) | ||
|
||
epsilon = 1e-12; | ||
dzdx_t = zeros(n1,n1); | ||
|
||
if abs(trace(X_t) - 3) <= epsilon || abs(trace(X_t) + 1) <= epsilon | ||
dzdx_t = zeros(n1,n1); | ||
else | ||
X_s = (X_t(2,1)-X_t(1,2))^2+(X_t(3,1)-X_t(1,3))^2+(X_t(3,2)-X_t(2,3))^2; | ||
X_m = (X_t(1,1)+X_t(2,2)+X_t(3,3)-1)/2; | ||
|
||
dzdx_t(1) = -1/sqrt(1-(X_m)^2)*0.5; | ||
dzdx_t(2) = X_s^(-0.5)-(X_t(2,1)-X_t(1,2))^2*X_s^(-1.5); | ||
dzdx_t(3) = X_s^(-0.5)-(X_t(3,1)-X_t(1,3))^2*X_s^(-1.5); | ||
dzdx_t(6) = X_s^(-0.5)-(X_t(3,2)-X_t(2,3))^2*X_s^(-1.5); | ||
|
||
end | ||
|
||
dzdx = D_t.*dzdx_t; | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
function [Y, R] = vl_mypooling(R, pool, dzdy) | ||
%logarithm mapping (LogMap) layer | ||
X = R.x; | ||
A = R.aux; | ||
[n1,n2,n3,n4,n5] = size(X); | ||
if pool == 2 | ||
Y = zeros(n1,n2,ceil(n3/pool),n4,n5); | ||
else | ||
Y = zeros(n1,n2,n3,ceil(n4/pool),n5); | ||
end | ||
|
||
if isempty(A)==1 | ||
IY = zeros(n1,n2,n3,n4,n5); | ||
Id = 1 : n1*n2; | ||
if pool == 2 | ||
for i4 = 1 : n4 | ||
for i5 = 1 : n5 | ||
for i3 = 1 : pool : n3 | ||
if i3+(pool-1) <=n3 | ||
r_tt = X(:,:, i3:i3+(pool-1),i4,i5); | ||
[m,I] = maxRotAngel(r_tt); | ||
I1 = I*ones(n1*n2,1); | ||
tI = zeros(n1,n2,pool); | ||
tI((I1-1)*n1*n2+Id') = 1; | ||
IY(:,:, i3: i3+(pool-1),i4,i5) = tI; | ||
else | ||
r_tt = X(:,:, i3:end,i4,i5); | ||
[m,I] = maxRotAngel(r_tt); | ||
I1 = I*ones(n1*n2,1); | ||
tI = zeros(n1,n2,n3-i3+1); | ||
tI((I1-1)*n1*n2+Id') = 1; | ||
IY(:,:, i3: end,i4,i5) = tI; | ||
end | ||
Y(:,:,floor(i3/pool)+1,i4,i5) = r_tt(:,:,I); | ||
end | ||
end | ||
end | ||
else | ||
for i3 = 1 : n3 | ||
for i5 = 1 : n5 | ||
for i4 = 1 : pool : n4 | ||
if i4+(pool-1) <=n4 | ||
r_tt = X(:,:,i3, i4:i4+(pool-1),i5); | ||
r_tt = reshape(r_tt, n1,n2,pool); | ||
[m,I] = maxRotAngel(r_tt); | ||
I1 = I*ones(n1*n2,1); | ||
tI = zeros(n1,n2,pool); | ||
tI((I1-1)*n1*n2+Id') = 1; | ||
IY(:,:,i3, i4: i4+(pool-1),i5) = tI; | ||
else | ||
r_tt = X(:,:, i3,i4:end,i5); | ||
r_tt = reshape(r_tt, n1,n2,n4-i4+1); | ||
|
||
[m,I] = maxRotAngel(r_tt); | ||
I1 = I*ones(n1*n2,1); | ||
tI = zeros(n1,n2,n4-i4+1); | ||
tI((I1-1)*n1*n2+Id') = 1; | ||
IY(:,:, i3,i4: end,i5) = tI; | ||
end | ||
Y(:,:,i3,floor(i4/pool)+1,i5) = r_tt(:,:,I); | ||
end | ||
end | ||
end | ||
end | ||
R.aux = IY; | ||
else | ||
Y = zeros(n1,n2,n3,n4,n5); | ||
Y(logical(A)) = dzdy; | ||
|
||
end | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
function [Y, R] = vl_myrelu(R, dzdy) | ||
%ReLU layer | ||
X = R.x; | ||
A = R.aux; | ||
[n1,n2,n3,n4,n5] = size(X); | ||
Y = zeros(n1,n2,n3,n4,n5); | ||
epslon = 0.3; | ||
|
||
if isempty(A) ==1 | ||
A = zeros(n1,n2,n3,n4,n5); | ||
parfor i3 = 1 : n3 | ||
for i4 = 1 : n4 | ||
for i5 = 1 : n5 | ||
r_t = X(:,:,i3,i4,i5); | ||
Ir_t1 = zeros(3,3); | ||
Ir_t1([1 2 3 6]) = abs(r_t([1 2 3 6])) < epslon; | ||
Ir_t3 = r_t < 0; | ||
r_t(Ir_t1 & Ir_t3) = -epslon; | ||
r_t(Ir_t1 & ~Ir_t3) = epslon; | ||
A(:,:,i3,i4,i5) = Ir_t1; | ||
Y(:,:,i3,i4,i5) = r_t; | ||
end | ||
end | ||
end | ||
R.aux = A; | ||
else | ||
dzdy = reshape(dzdy,n1,n2,n3,n4,n5); | ||
|
||
parfor i3 = 1 : n3 | ||
for i4 = 1 : n4 | ||
for i5 = 1 : n5 | ||
dzdy_t = dzdy(:,:,i3,i4,i5); | ||
Ir1 = A(:,:,i3,i4,i5); | ||
dzdy_t = dzdy_t .* not(Ir1);%through ReLU | ||
Y(:,:,i3,i4,i5) = dzdy_t; | ||
end | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
function [Y, Y_w] = vl_myrotmap(X, W, dzdy) | ||
%rotation mapping (RotMap) layers | ||
|
||
[n1,n2,n3,n4,n5] = size(X); | ||
Y = zeros(n1,n2,n3,n4,n5); | ||
|
||
if nargin < 3 | ||
parfor i3 = 1 : n3 | ||
for i4 = 1 : n4 | ||
for i5 = 1 : n5 | ||
Y(:,:,i3,i4,i5) = W(:,:,i3)*X(:,:,i3,i4,i5); | ||
end | ||
end | ||
end | ||
else | ||
Y_w = zeros(n1,n2,n3); | ||
dzdy = reshape(dzdy,n1,n2,n3,n4,n5); | ||
parfor i3 = 1 : n3 | ||
for i4 = 1 : n4 | ||
for i5 = 1 : n5 | ||
d_t = dzdy(:,:,i3,i4,i5); | ||
Y(:,:,i3,i4,i5) = W(:,:,i3)'*d_t; | ||
Y_w(:,:,i3) = Y_w(:,:,i3)+ d_t*X(:,:,i3,i4,i5)'; | ||
end | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
function Y = vl_mysoftmaxloss(X,c,dzdy) | ||
|
||
% classNum = size(c,2); | ||
% s = X; | ||
% s = bsxfun(@minus, s, max(s, [], 1)); | ||
% s = exp(s); | ||
% s = s + 1e-8; %avoid NaN | ||
% y = s./repmat(sum(s,1),[classNum 1]); | ||
% if nargin < 3 | ||
% loss = -c'.* log(y); | ||
% Y = sum(loss(:)); | ||
% else | ||
% % gradient computation | ||
% Y =dzdy* (-1) *(c'-y);%loss delta or output delta | ||
% end | ||
|
||
% class c = 0 skips a spatial location | ||
mass = single(c > 0) ; | ||
mass = mass'; | ||
|
||
% convert to indexes | ||
c_ = c - 1 ; | ||
for ic = 1 : length(c) | ||
c_(ic) = c(ic)+(ic-1)*size(X,1); | ||
end | ||
|
||
% compute softmaxloss | ||
Xmax = max(X,[],1) ; | ||
ex = exp(bsxfun(@minus, X, Xmax)) ; | ||
|
||
% s = bsxfun(@minus, X, Xmax); | ||
% ex = exp(s) ; | ||
% y = ex./repmat(sum(ex,1),[size(X,1) 1]); | ||
|
||
%n = sz(1)*sz(2) ; | ||
if nargin < 3 | ||
t = Xmax + log(sum(ex,1)) - reshape(X(c_), [1 size(X,2)]) ; | ||
Y = sum(sum(mass .* t,1)) ; | ||
else | ||
Y = bsxfun(@rdivide, ex, sum(ex,1)) ; | ||
Y(c_) = Y(c_) - 1; | ||
Y = bsxfun(@times, Y, bsxfun(@times, mass, dzdy)) ; | ||
end |