-
Notifications
You must be signed in to change notification settings - Fork 0
/
pooling_layer_backward.m
55 lines (48 loc) · 2.2 KB
/
pooling_layer_backward.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
function [input_od] = pooling_layer_backward(output, input, layer)
%POOLING_LAYER_BACKWARD Back-propagate gradients through a k x k pooling layer.
%
% function input:
%   input          : input of pooling_layer_forward; uses .data, .height,
%                    .width, .channel, .batch_size
%   output         : output of pooling_layer_forward; uses .height, .width,
%                    .channel, .batch_size and .diff (gradient w.r.t.
%                    output.data). output.data itself is not needed: for
%                    'MAX' the argmax is recomputed from input.data.
%   layer.k        : kernel size of pooling operation (square k x k window)
%   layer.stride   : stride of pooling operation
%   layer.pad      : pad of pooling operation -- NOT read here; windows are
%                    indexed directly into the unpadded input.
%                    NOTE(review): assumes pad == 0; confirm against
%                    pooling_layer_forward.
%   layer.act_type : 'MAX' or 'AVE'; any other value raises an error
%
% function output:
%   input_od       : gradient w.r.t. input.data, same size as input.data

k      = layer.k;
stride = layer.stride;
H = input.height;
W = input.width;
C = input.channel;
B = input.batch_size;

% View the flat data as 4-D (height x width x channel x batch) arrays.
modin   = reshape(input.data, [H, W, C, B]);
moddiff = reshape(output.diff, [output.height, output.width, output.channel, output.batch_size]);
modinod = zeros(H, W, C, B);

switch layer.act_type
    case 'MAX'
        % Route each output gradient to exactly ONE input position: the first
        % maximum of its window (column-major order, as returned by MAX).
        % Crediting every position equal to the pooled value would give each
        % tied position the full gradient and inflate it.
        for m = 1:output.height
            rows = (m-1)*stride + (1:k);
            for n = 1:output.width
                cols = (n-1)*stride + (1:k);
                % One column per (channel, batch) pair, channel index fastest.
                win = reshape(modin(rows, cols, :, :), k*k, []);
                [~, idx] = max(win, [], 1);
                grad = zeros(size(win));
                grad(sub2ind(size(grad), idx, 1:size(win, 2))) = ...
                    reshape(moddiff(m, n, :, :), 1, []);
                % Accumulate: windows may overlap when stride < k.
                modinod(rows, cols, :, :) = modinod(rows, cols, :, :) + ...
                    reshape(grad, k, k, C, B);
            end
        end
    case 'AVE'
        % Every input in a window receives an equal 1/(k*k) share.
        for m = 1:output.height
            rows = (m-1)*stride + (1:k);
            for n = 1:output.width
                cols = (n-1)*stride + (1:k);
                share = repmat(moddiff(m, n, :, :), k, k) / k / k;
                modinod(rows, cols, :, :) = modinod(rows, cols, :, :) + share;
            end
        end
    otherwise
        % Fail loudly instead of silently returning an all-zero gradient.
        error('pooling_layer_backward:unknownActType', ...
              'Unsupported pooling act_type: %s', layer.act_type);
end

input_od = reshape(modinod, size(input.data));
end