-
Notifications
You must be signed in to change notification settings - Fork 479
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SYSTEMDS-3651] ResNet residual block forward path
This patch implements the forward path for the basic residual block of ResNet and applies differential testing comparing results with PyTorch. The testing is done by, first, testing for the correct handling of dimension throughout the residual bock. Secondly, we test for correct computation of the actual values of the output by using the torchvision.models.resnet.BasicBlock of PyTorch which is there implementation of the basic block for the ResNets 18 and 34. We let PyTorch randomly initialize all weights and then hard-code these weights into the component test with the expected values as well. Closes #1940
- Loading branch information
1 parent
b6080cf
commit 9de62d1
Showing
4 changed files
with
1,042 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,364 @@ | ||
#------------------------------------------------------------- | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
# | ||
#------------------------------------------------------------- | ||
|
||
source("scripts/nn/layers/batch_norm2d_old.dml") as bn2d | ||
source("scripts/nn/layers/conv2d_builtin.dml") as conv2d | ||
source("scripts/nn/layers/relu.dml") as relu | ||
source("scripts/nn/layers/max_pool2d_builtin.dml") as mp2d | ||
source("scripts/nn/layers/global_avg_pool2d.dml") as ap2d | ||
source("scripts/nn/layers/affine.dml") as fc | ||
|
||
conv3x3_forward = function(matrix[double] X, matrix[double] filter, | ||
int C_in, int C_out, int Hin, int Win, | ||
int strideh, int stridew) | ||
return(matrix[double] out, int Hout, int Wout) { | ||
/* | ||
* Simple 2D conv layer with 3x3 filter | ||
*/ | ||
# bias should not be applied | ||
bias = matrix(0, C_out, 1) | ||
[out, Hout, Wout] = conv2d::forward(X, filter, bias, C_in, Hin, Win, | ||
3, 3, strideh, stridew, 1, 1) | ||
} | ||
|
||
conv1x1_forward = function(matrix[double] X, matrix[double] filter, | ||
int C_in, int C_out, int Hin, int Win, | ||
int strideh, int stridew) | ||
return(matrix[double] out, int Hout, int Wout) { | ||
/* | ||
* Simple 2D conv layer with 1x1 filter | ||
*/ | ||
# bias should not be applied | ||
bias = matrix(0, C_out, 1) | ||
[out, Hout, Wout] = conv2d::forward(X, filter, bias, C_in, Hin, Win, | ||
1, 1, strideh, stridew, 0, 0) | ||
} | ||
|
||
basic_block_forward = function(matrix[double] X, list[unknown] weights, | ||
int C_in, int C_base, int Hin, int Win, | ||
int strideh, int stridew, string mode, | ||
list[unknown] ema_means_vars) | ||
return (matrix[double] out, int Hout, int Wout, | ||
list[unknown] ema_means_vars_upd) { | ||
/* | ||
* Computes the forward pass for a basic residual block. | ||
* This basic residual block (with 2 3x3 conv layers of | ||
* same channel size) is used in the smaller ResNets 18 | ||
* and 34. | ||
* | ||
* Inputs: | ||
* - X: Inputs, of shape (N, C_in*Hin*Win). | ||
* - weights: list of weights for all layers of res block | ||
* with the following order/content: | ||
* -> 1: Weights of conv 1, of shape (C_base, C_in*3*3). | ||
* -> 2: Weights of batch norm 1, of shape (C_base, 1). | ||
* -> 3: Bias of batch norm 1, of shape (C_base, 1). | ||
* -> 4: Weights of conv 2, of shape (C_base, C_base*3*3). | ||
* -> 5: Weights of batch norm 2, of shape (C_base, 1). | ||
* -> 6: Bias of batch norm 2, of shape (C_base, 1). | ||
* If the block should downsample X: | ||
* -> 7: Weights of downsample conv, of shape (C_base, C_in*3*3). | ||
* -> 8: Weights of downsample batch norm, of shape (C_base, 1). | ||
* -> 9: Bias of downsample batch norm, of shape (C_base, 1). | ||
* - C_in: Number of input channels. | ||
* - C_base: Number of base channels for this block. | ||
* - Hin: Input height. | ||
* - Win: Input width. | ||
* - strideh: Stride over height (usually 1 or 2).. | ||
* - stridew: Stride over width (usually same as strideh). | ||
* - mode: 'train' or 'test' to indicate if the model is currently | ||
* being trained or tested for badge normalization layers. | ||
* See badge_norm2d.dml docs for more info. | ||
* - ema_means_vars: List of exponential moving averages for mean | ||
* and variance for badge normalization layers. | ||
* -> 1: EMA for mean of badge norm 1, of shape (C_base, 1). | ||
* -> 2: EMA for variance of badge norm 1, of shape (C_base, 1). | ||
* -> 3: EMA for mean of badge norm 2, of shape (C_base, 1). | ||
* -> 4: EMA for variance of badge norm 2, of shape (C_base, 1). | ||
* If the block should downsample X: | ||
* -> 5: EMA for mean of downs. badge norm, of shape (C_base, 1). | ||
* -> 6: EMA for variance of downs. badge norm, of shape (C_base, 1). | ||
* | ||
* Outputs: | ||
* - out: Output, of shape (N, C_base*Hout*Wout). | ||
* - Hout: Output height. | ||
* - Wout: Output width. | ||
* - ema_means_vars_upd: List of updated exponential moving averages | ||
* for mean and variance of badge normalization layers. | ||
*/ | ||
downsample = strideh > 1 | stridew > 1 | C_in != C_base | ||
# default values | ||
mu_bn = 0.1 | ||
epsilon_bn = 1e-05 | ||
|
||
# get all params | ||
W_conv1 = as.matrix(weights[1]) | ||
gamma_bn1 = as.matrix(weights[2]) | ||
beta_bn1 = as.matrix(weights[3]) | ||
W_conv2 = as.matrix(weights[4]) | ||
gamma_bn2 = as.matrix(weights[5]) | ||
beta_bn2 = as.matrix(weights[6]) | ||
|
||
ema_mean_bn1 = as.matrix(ema_means_vars[1]) | ||
ema_var_bn1 = as.matrix(ema_means_vars[2]) | ||
ema_mean_bn2 = as.matrix(ema_means_vars[3]) | ||
ema_var_bn2 = as.matrix(ema_means_vars[4]) | ||
|
||
if (downsample) { | ||
# gather params for donwsampling | ||
W_conv3 = as.matrix(weights[7]) | ||
gamma_bn3 = as.matrix(weights[8]) | ||
beta_bn3 = as.matrix(weights[9]) | ||
ema_mean_bn3 = as.matrix(ema_means_vars[5]) | ||
ema_var_bn3 = as.matrix(ema_means_vars[6]) | ||
} | ||
|
||
# RESIDUAL PATH | ||
# First convolutional layer | ||
[out, Hout, Wout] = conv3x3_forward(X, W_conv1, C_in, C_base, Hin, Win, | ||
strideh, stridew) | ||
[out, ema_mean_bn1_upd, ema_var_bn1_upd, c_m, c_v] = bn2d::forward(out, gamma_bn1, | ||
beta_bn1, C_base, Hout, Wout, | ||
mode, ema_mean_bn1, ema_var_bn1, | ||
mu_bn, epsilon_bn) | ||
out = relu::forward(out) | ||
|
||
# Second convolutional layer | ||
[out, Hout, Wout] = conv3x3_forward(out, W_conv2, C_base, C_base, Hout, | ||
Wout, 1, 1) | ||
[out, ema_mean_bn2_upd, ema_var_bn2_upd, c_m, c_v] = bn2d::forward(out, gamma_bn2, | ||
beta_bn2, C_base, Hout, Wout, | ||
mode, ema_mean_bn2, ema_var_bn2, | ||
mu_bn, epsilon_bn) | ||
|
||
# IDENTITY PATH | ||
identity = X | ||
if (downsample) { | ||
# downsample input | ||
[identity, Hout, Wout] = conv1x1_forward(X, W_conv3, C_in, C_base, | ||
Hin, Win, strideh, stridew) | ||
[identity, ema_mean_bn3_upd, ema_var_bn3_upd, c_m, c_v] = bn2d::forward(identity, | ||
gamma_bn3, beta_bn3, C_base, Hout, Wout, | ||
mode, ema_mean_bn3, ema_var_bn3, mu_bn, | ||
epsilon_bn) | ||
} | ||
|
||
out = relu::forward(out + identity) | ||
|
||
ema_means_vars_upd = list(ema_mean_bn1_upd, ema_var_bn1_upd, ema_mean_bn2_upd, ema_var_bn2_upd) | ||
if (downsample) { | ||
ema_means_vars_upd = append(ema_means_vars, ema_mean_bn3_upd) | ||
ema_means_vars_upd = append(ema_means_vars, ema_var_bn3_upd) | ||
} | ||
} | ||
|
||
basic_reslayer_forward = function(matrix[double] X, int Hin, int Win, int blocks, | ||
int strideh, int stridew, int C_in, int C_base, | ||
list[unknown] blocks_weights, string mode, | ||
list[unknown] ema_means_vars) | ||
return (matrix[double] out, int Hout, int Wout, | ||
list[unknown] ema_means_vars_upd) { | ||
/* | ||
* Executes the forward pass for a sequence of residual blocks | ||
* with the same number of base channels, i.e. residual layer. | ||
* | ||
* Inputs: | ||
* - X: Inputs, of shape (N, C_in*Hin*Win) | ||
* - Hin: Input height. | ||
* - Win: Input width. | ||
* - blocks: Number of residual blocks (bigger than 0). | ||
* - strideh: Stride height for first conv layer of first block. | ||
* - stridew: Stride width for first conv layer of first block. | ||
* - C_in: Number of input channels. | ||
* - C_base: Number of base channels of res layer. | ||
* - blocks_weights: List of weights of each block. | ||
* -> i: List of weights of block i with the content | ||
* defined in the docs of basic_block_forward(). | ||
* -> length == blocks | ||
* - mode: 'train' or 'test' to indicate if the model is currently | ||
* being trained or tested for badge normalization layers. | ||
* See badge_norm2d.dml docs for more info. | ||
* - ema_means_vars: List of exponential moving averages for mean | ||
* and variance for badge normalization layers of each block. | ||
* -> i: List of EMAs of block i with the content defined | ||
* in the docs of basic_block_forward(). | ||
* -> length == blocks | ||
*/ | ||
# default values | ||
mu_bn = 0.1 | ||
epsilon_bn = 1e-05 | ||
|
||
# first block with provided stride | ||
[out, Hout, Wout, emas1_upd] = basic_block_forward(X, as.list(blocks_weights[1]), | ||
C_in, C_base, Hin, Win, strideh, stridew, | ||
mode, as.list(ema_means_vars[1])) | ||
ema_means_vars_upd = list(emas1_upd) | ||
|
||
# other block | ||
for (i in 2:blocks) { | ||
current_weights = as.list(blocks_weights[i]) | ||
current_emas = as.list(ema_means_vars[i]) | ||
[out, Hout, Wout, current_emas_upd] = basic_block_forward(X=out, | ||
weights=current_weights, C_in=C_base, C_base=C_base, | ||
Hin=Hout, Win=Wout, strideh=1, stridew=1, mode=mode, | ||
ema_means_vars=current_emas) | ||
ema_means_vars_upd = append(ema_means_vars_upd, current_emas_upd) | ||
} | ||
} | ||
|
||
resnet18_forward = function(matrix[double] X, int Hin, int Win, | ||
list[unknown] model, string mode, | ||
list[unknown] ema_means_vars) | ||
return (matrix[double] out, list[unknown] ema_means_vars_upd) { | ||
/* | ||
* Forward pass of the ResNet 18 model as introduced in | ||
* "Deep Residual Learning for Image Recognition" by | ||
* Kaiming He et. al. and inspired by the PyTorch | ||
* implementation. | ||
* | ||
* Inputs: | ||
* - X: Inputs, of shape (N, C_in*Hin*Win). | ||
* C_in = 3 is expected. | ||
* - Hin: Input height. | ||
* - Win: Input width. | ||
* - model: Weights and bias matrices of the model | ||
* with the following order/content: | ||
* -> 1: Weights of conv 1 7x7, of shape (64, 3*7*7) | ||
* -> 2: Weights of batch norm 1, of shape (64, 1). | ||
* -> 3: Bias of batch norm 1, of shape (64, 1). | ||
* -> 4: List of weights for first residual layer | ||
* with 64 base channels. | ||
* -> 5: List of weights for second residual layer | ||
* with 128 base channels. | ||
* -> 6: List of weights for third residual layer | ||
* with 256 base channels. | ||
* -> 7: List of weights for fourth residual layer | ||
* with 512 base channels. | ||
* List of residual layers 1, 2, 3 & 4 have | ||
* the content/order: | ||
* -> 1: List of weights for first residual | ||
* block. | ||
* -> 2: List of weights for second residual | ||
* block. | ||
* Each list of weights for a residual block | ||
* must follow the same order as defined in | ||
* the documentation of basic_block_forward(). | ||
* -> 8: Weights of fully connected layer, of shape (512, 1000) | ||
* -> 9: Bias of fully connected layer, of shape (1, 1000) | ||
* - mode: 'train' or 'test' to indicate if the model is currently | ||
* being trained or tested for badge normalization layers. | ||
* See badge_norm2d.dml docs for more info. | ||
* - ema_means_vars: List of exponential moving averages for mean | ||
* and variance for badge normalization layers. | ||
* -> 1: EMA for mean of badge norm 1, of shape (64, 1). | ||
* -> 2: EMA for variance of badge norm 1, of shape (64, 1). | ||
* -> 3: List of EMA means and vars for residual layer 1. | ||
* -> 4: List of EMA means and vars for residual layer 2. | ||
* -> 5: List of EMA means and vars for residual layer 3. | ||
* -> 6: List of EMA means and vars for residual layer 4. | ||
* Lists for EMAs of layer 1, 2, 3 & 4 must have the | ||
* following order: | ||
* -> 1: List of EMA means and vars for residual block 1. | ||
* -> 2: List of EMA means and vars for residual block 2. | ||
* Each list of EMAs for a residual block | ||
* must follow the same order as defined in | ||
* the documentation of basic_block_forward(). | ||
* - NOTICE: The lists of the first blocks for layer 2, 3 and 4 | ||
* must include weights and EMAs for 1 extra conv layer | ||
* and a batch norm layer for the downsampling on the | ||
* identity path. | ||
* | ||
* Outputs: | ||
* - out: Outputs, of shape (N, 1000) | ||
* - ema_means_vars_upd: List of updated exponential moving averages | ||
* for mean and variance of badge normalization layers. It follows | ||
* the same exact structure as the input EMAs list. | ||
*/ | ||
# default values | ||
mu_bn = 0.1 | ||
epsilon_bn = 1e-05 | ||
|
||
# extract model params | ||
W_conv1 = as.matrix(model[1]) | ||
gamma_bn1 = as.matrix(model[2]); beta_bn1 = as.matrix(model[3]) | ||
weights_reslayer1 = as.list(model[4]) | ||
weights_reslayer2 = as.list(model[5]) | ||
weights_reslayer3 = as.list(model[6]) | ||
weights_reslayer4 = as.list(model[7]) | ||
W_fc = as.matrix(model[8]) | ||
b_fc = as.matrix(model[9]) | ||
ema_mean_bn1 = as.matrix(ema_means_vars[1]); ema_var_bn1 = as.matrix(ema_means_vars[2]) | ||
emas_reslayer1 = as.list(ema_means_vars[3]) | ||
emas_reslayer2 = as.list(ema_means_vars[4]) | ||
emas_reslayer3 = as.list(ema_means_vars[5]) | ||
emas_reslayer4 = as.list(ema_means_vars[6]) | ||
|
||
# Convolutional 7x7 layer | ||
C = 64 | ||
b_conv1 = matrix(0, rows=C, cols=1) | ||
[out, Hout, Wout] = conv2d::forward(X=X, W=W_conv1, b=b_conv1, C=3, | ||
Hin=Hin, Win=Win, Hf=7, Wf=7, strideh=2, | ||
stridew=2, padh=3, padw=3) | ||
# Batch Normalization | ||
[out, ema_mean_bn1_upd, ema_var_bn1_upd, c_mean, c_var] = bn2d::forward(X=out, | ||
gamma=gamma_bn1, beta=beta_bn1, C=C, Hin=Hout, | ||
Win=Wout, mode=mode, ema_mean=ema_mean_bn1, | ||
ema_var=ema_var_bn1, mu=mu_bn, | ||
epsilon=epsilon_bn) | ||
# ReLU | ||
out = relu::forward(X=out) | ||
# Max Pooling 3x3 | ||
[out, Hout, Wout] = mp2d::forward(X=out, C=C, Hin=Hout, Win=Wout, Hf=3, | ||
Wf=3, strideh=2, stridew=2, padh=1, padw=1) | ||
|
||
# residual layer 1 | ||
[out, Hout, Wout, emas1_upd] = basic_reslayer_forward(X=out, Hin=Hout, | ||
Win=Wout, blocks=2, strideh=1, stridew=1, C_in=C, | ||
C_base=64, blocks_weights=weights_reslayer1, | ||
mode=mode, ema_means_vars=emas_reslayer1) | ||
C = 64 | ||
# residual layer 2 | ||
[out, Hout, Wout, emas2_upd] = basic_reslayer_forward(X=out, Hin=Hout, | ||
Win=Wout, blocks=2, strideh=2, stridew=2, C_in=C, | ||
C_base=128, blocks_weights=weights_reslayer2, | ||
mode=mode, ema_means_vars=emas_reslayer2) | ||
C = 128 | ||
# residual layer 3 | ||
[out, Hout, Wout, emas3_upd] = basic_reslayer_forward(X=out, Hin=Hout, | ||
Win=Wout, blocks=2, strideh=2, stridew=2, C_in=C, | ||
C_base=256, blocks_weights=weights_reslayer3, | ||
mode=mode, ema_means_vars=emas_reslayer3) | ||
C = 256 | ||
# residual layer 4 | ||
[out, Hout, Wout, emas4_upd] = basic_reslayer_forward(X=out, Hin=Hout, | ||
Win=Wout, blocks=2, strideh=2, stridew=2, C_in=C, | ||
C_base=512, blocks_weights=weights_reslayer4, | ||
mode=mode, ema_means_vars=emas_reslayer4) | ||
C = 512 | ||
|
||
# Global Average Pooling | ||
[out, Hout, Wout] = ap2d::forward(X=out, C=C, Hin=Hout, Win=Wout) | ||
# Affine | ||
out = fc::forward(X=out, W=W_fc, b=b_fc) | ||
|
||
ema_means_vars_upd = list(ema_mean_bn1_upd, ema_var_bn1_upd, | ||
emas1_upd, emas2_upd, emas3_upd, emas4_upd) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.