77
88contains
99
10- module function conv2d_layer_cons (filters , kernel_size , activation ) result(res)
10+ module function conv2d_layer_cons (filters , kernel_size , activation , stride ) result(res)
1111 implicit none
1212 integer , intent (in ) :: filters
1313 integer , intent (in ) :: kernel_size
1414 class(activation_function), intent (in ) :: activation
15+ integer , intent (in ) :: stride(:)
1516 type (conv2d_layer) :: res
1617
1718 res % kernel_size = kernel_size
1819 res % filters = filters
1920 res % activation_name = activation % get_name()
21+ res % stride = stride
2022 allocate ( res % activation, source = activation )
2123
2224 end function conv2d_layer_cons
@@ -28,8 +30,12 @@ module subroutine init(self, input_shape)
2830 integer , intent (in ) :: input_shape(:)
2931
3032 self % channels = input_shape(1 )
31- self % width = input_shape(2 ) - self % kernel_size + 1
32- self % height = input_shape(3 ) - self % kernel_size + 1
33+
34+ self % width = (input_shape(2 ) - self % kernel_size) / self % stride(1 ) + 1
35+ if (mod (input_shape(2 ) - self % kernel_size , self % stride(1 )) /= 0 ) self % width = self % width + 1
36+
37+ self % height = (input_shape(3 ) - self % kernel_size) / self % stride(2 ) + 1
38+ if (mod (input_shape(3 ) - self % kernel_size , self % stride(2 )) /= 0 ) self % height = self % height + 1
3339
3440 ! Output of shape filters x width x height
3541 allocate (self % output(self % filters, self % width, self % height))
@@ -83,25 +89,24 @@ pure module subroutine forward(self, input)
8389 ! of the input that correspond to the center of each window.
8490 istart = half_window + 1 ! TODO kernel_width
8591 jstart = half_window + 1 ! TODO kernel_height
86- iend = input_width - istart + 1
87- jend = input_height - jstart + 1
8892
89- convolution: do concurrent(i = istart:iend , j = jstart:jend )
93+ convolution: do concurrent(i = 1 :self % width , j = 1 :self % height )
9094
9195 ! Start and end indices of the input data on the filter window
9296 ! iws and jws are also coincidentally the indices of the output matrix
93- iws = i - half_window ! TODO kernel_width
94- iwe = i + half_window ! TODO kernel_width
95- jws = j - half_window ! TODO kernel_height
96- jwe = j + half_window ! TODO kernel_height
97+ iws = istart + self % stride(1 ) * (i-1 ) - half_window ! TODO kernel_width
98+ iwe = min (iws + 2 * half_window, input_width) ! TODO kernel_width
99+
100+ jws = jstart + self % stride(2 ) * (j-1 ) - half_window ! TODO kernel_height
101+ jwe = min (jws + 2 * half_window, input_height) ! TODO kernel_height
97102
98103 ! Compute the inner tensor product, sum(w_ij * x_ij), for each filter.
99104 do concurrent(n = 1 :self % filters)
100- self % z(n,iws,jws ) = sum (self % kernel(n,:,:,: ) * input(:,iws:iwe,jws:jwe))
105+ self % z(n,i,j ) = sum (self % kernel(n,:,1 :iwe - iws +1 , 1 :jwe - jws +1 ) * input(:,iws:iwe,jws:jwe))
101106 end do
102107
103108 ! Add bias to the inner product.
104- self % z(:,iws,jws ) = self % z(:,iws,jws ) + self % biases
109+ self % z(:,i,j ) = self % z(:,i,j ) + self % biases
105110
106111 end do convolution
107112
@@ -156,21 +161,22 @@ pure module subroutine backward(self, input, gradient)
156161 do concurrent( &
157162 n = 1 :self % filters, &
158163 k = 1 :self % channels, &
159- i = istart:iend , &
160- j = jstart:jend &
164+ i = 1 :self % width , &
165+ j = 1 :self % height &
161166 )
162167 ! Start and end indices of the input data on the filter window
163- iws = i - half_window ! TODO kernel_width
164- iwe = i + half_window ! TODO kernel_width
165- jws = j - half_window ! TODO kernel_height
166- jwe = j + half_window ! TODO kernel_height
168+ iws = istart + self % stride(1 ) * (i-1 ) - half_window ! TODO kernel_width
169+ iwe = min (iws + 2 * half_window, input_width) ! TODO kernel_width
170+
171+ jws = jstart + self % stride(2 ) * (j-1 ) - half_window ! TODO kernel_height
172+ jwe = min (jws + 2 * half_window, input_height) ! TODO kernel_height
167173
168- ! dL/dw = sum(dL/dy * sigma'(z) * x)
169- dw(n,k,:,:) = dw(n,k,:,:) + input(k,iws:iwe,jws:jwe) * gdz(n,iws:iwe,jws:jwe )
174+ ! dL/dw = sum(gdz * x)
175+ dw(n,k,:,:) = dw(n,k,:,:) + input(k,iws:iwe,jws:jwe) * gdz(n,i,j )
170176
171- ! dL/dx = dL/dy * sigma'(z) .inner. w
172- self % gradient(k,i,j ) = self % gradient(k,i,j ) &
173- + sum ( gdz(n,iws:iwe,jws:jwe ) * self % kernel(n,k,:,:) )
177+ ! dL/dx = sum(gdz * w)
178+ self % gradient(k,iws:iwe,jws:jwe ) = self % gradient(k,iws:iwe,jws:jwe ) &
179+ + gdz(n,i,j ) * self % kernel(n,k,:,:)
174180
175181 end do
176182
0 commit comments