[SYSTEMDS-3651] ResNet residual block forward path
This patch implements the forward path for the basic residual block
of ResNet and validates it by differential testing against PyTorch.
The tests first check that dimensions are handled correctly
throughout the residual block. Second, they verify the actual output
values against torchvision.models.resnet.BasicBlock, PyTorch's
implementation of the basic block used in ResNet-18 and ResNet-34.
We let PyTorch randomly initialize all weights and then hard-code
these weights, together with the expected outputs, into the
component test.
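The computation under test (two 3x3 convolutions plus a shortcut, then ReLU) can be sketched in NumPy. This is an illustrative simplification only — a single un-batched example with stride 1, batch normalization omitted, and names that do not come from the patch:

```python
import numpy as np

def conv3x3(x, w):
    """Naive same-padded 3x3 cross-correlation (as in PyTorch's conv).
    x: (C_in, H, W), w: (C_out, C_in, 3, 3)."""
    c_out, h, wd = w.shape[0], x.shape[1], x.shape[2]
    xp = np.pad(x, ((0, 0), (1, 1), (1, 1)))
    out = np.zeros((c_out, h, wd))
    for o in range(c_out):
        for i in range(h):
            for j in range(wd):
                out[o, i, j] = np.sum(xp[:, i:i + 3, j:j + 3] * w[o])
    return out

def basic_block(x, w1, w2):
    """Simplified basic residual block: conv -> ReLU -> conv,
    add the identity path, final ReLU. Batch norm omitted."""
    out = np.maximum(conv3x3(x, w1), 0.0)  # conv 1 + ReLU
    out = conv3x3(out, w2)                 # conv 2
    return np.maximum(out + x, 0.0)        # shortcut + ReLU

x = np.random.randn(4, 8, 8)
w1 = np.random.randn(4, 4, 3, 3)
w2 = np.random.randn(4, 4, 3, 3)
y = basic_block(x, w1, w2)
print(y.shape)  # (4, 8, 8): stride 1 preserves the dimensions
```

With stride 1 and matching channel counts the output shape equals the input shape, which is the first property the dimension tests check.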

Closes #1940
MaximilianSchreff authored and phaniarnab committed Nov 8, 2023
1 parent b6080cf commit 9de62d1
Showing 4 changed files with 1,042 additions and 0 deletions.
364 changes: 364 additions & 0 deletions scripts/nn/networks/resnet.dml
@@ -0,0 +1,364 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

source("scripts/nn/layers/batch_norm2d_old.dml") as bn2d
source("scripts/nn/layers/conv2d_builtin.dml") as conv2d
source("scripts/nn/layers/relu.dml") as relu
source("scripts/nn/layers/max_pool2d_builtin.dml") as mp2d
source("scripts/nn/layers/global_avg_pool2d.dml") as ap2d
source("scripts/nn/layers/affine.dml") as fc

conv3x3_forward = function(matrix[double] X, matrix[double] filter,
int C_in, int C_out, int Hin, int Win,
int strideh, int stridew)
return(matrix[double] out, int Hout, int Wout) {
/*
* Simple 2D conv layer with 3x3 filter
*/
# bias should not be applied
bias = matrix(0, C_out, 1)
[out, Hout, Wout] = conv2d::forward(X, filter, bias, C_in, Hin, Win,
3, 3, strideh, stridew, 1, 1)
}

conv1x1_forward = function(matrix[double] X, matrix[double] filter,
int C_in, int C_out, int Hin, int Win,
int strideh, int stridew)
return(matrix[double] out, int Hout, int Wout) {
/*
* Simple 2D conv layer with 1x1 filter
*/
# bias should not be applied
bias = matrix(0, C_out, 1)
[out, Hout, Wout] = conv2d::forward(X, filter, bias, C_in, Hin, Win,
1, 1, strideh, stridew, 0, 0)
}

basic_block_forward = function(matrix[double] X, list[unknown] weights,
int C_in, int C_base, int Hin, int Win,
int strideh, int stridew, string mode,
list[unknown] ema_means_vars)
return (matrix[double] out, int Hout, int Wout,
list[unknown] ema_means_vars_upd) {
/*
* Computes the forward pass for a basic residual block.
* This basic residual block (two 3x3 conv layers with the
* same channel size) is used in the smaller ResNets 18
* and 34.
*
* Inputs:
* - X: Inputs, of shape (N, C_in*Hin*Win).
* - weights: list of weights for all layers of res block
* with the following order/content:
* -> 1: Weights of conv 1, of shape (C_base, C_in*3*3).
* -> 2: Weights of batch norm 1, of shape (C_base, 1).
* -> 3: Bias of batch norm 1, of shape (C_base, 1).
* -> 4: Weights of conv 2, of shape (C_base, C_base*3*3).
* -> 5: Weights of batch norm 2, of shape (C_base, 1).
* -> 6: Bias of batch norm 2, of shape (C_base, 1).
* If the block should downsample X:
* -> 7: Weights of downsample conv, of shape (C_base, C_in*1*1).
* -> 8: Weights of downsample batch norm, of shape (C_base, 1).
* -> 9: Bias of downsample batch norm, of shape (C_base, 1).
* - C_in: Number of input channels.
* - C_base: Number of base channels for this block.
* - Hin: Input height.
* - Win: Input width.
* - strideh: Stride over height (usually 1 or 2).
* - stridew: Stride over width (usually same as strideh).
* - mode: 'train' or 'test' to indicate whether the model is currently
* being trained or tested, for the batch normalization layers.
* See batch_norm2d_old.dml docs for more info.
* - ema_means_vars: List of exponential moving averages for mean
* and variance for the batch normalization layers.
* -> 1: EMA for mean of batch norm 1, of shape (C_base, 1).
* -> 2: EMA for variance of batch norm 1, of shape (C_base, 1).
* -> 3: EMA for mean of batch norm 2, of shape (C_base, 1).
* -> 4: EMA for variance of batch norm 2, of shape (C_base, 1).
* If the block should downsample X:
* -> 5: EMA for mean of downs. batch norm, of shape (C_base, 1).
* -> 6: EMA for variance of downs. batch norm, of shape (C_base, 1).
*
* Outputs:
* - out: Output, of shape (N, C_base*Hout*Wout).
* - Hout: Output height.
* - Wout: Output width.
* - ema_means_vars_upd: List of updated exponential moving averages
* for mean and variance of the batch normalization layers.
*/
downsample = strideh > 1 | stridew > 1 | C_in != C_base
# default values
mu_bn = 0.1
epsilon_bn = 1e-05

# get all params
W_conv1 = as.matrix(weights[1])
gamma_bn1 = as.matrix(weights[2])
beta_bn1 = as.matrix(weights[3])
W_conv2 = as.matrix(weights[4])
gamma_bn2 = as.matrix(weights[5])
beta_bn2 = as.matrix(weights[6])

ema_mean_bn1 = as.matrix(ema_means_vars[1])
ema_var_bn1 = as.matrix(ema_means_vars[2])
ema_mean_bn2 = as.matrix(ema_means_vars[3])
ema_var_bn2 = as.matrix(ema_means_vars[4])

if (downsample) {
# gather params for downsampling
W_conv3 = as.matrix(weights[7])
gamma_bn3 = as.matrix(weights[8])
beta_bn3 = as.matrix(weights[9])
ema_mean_bn3 = as.matrix(ema_means_vars[5])
ema_var_bn3 = as.matrix(ema_means_vars[6])
}

# RESIDUAL PATH
# First convolutional layer
[out, Hout, Wout] = conv3x3_forward(X, W_conv1, C_in, C_base, Hin, Win,
strideh, stridew)
[out, ema_mean_bn1_upd, ema_var_bn1_upd, c_m, c_v] = bn2d::forward(out, gamma_bn1,
beta_bn1, C_base, Hout, Wout,
mode, ema_mean_bn1, ema_var_bn1,
mu_bn, epsilon_bn)
out = relu::forward(out)

# Second convolutional layer
[out, Hout, Wout] = conv3x3_forward(out, W_conv2, C_base, C_base, Hout,
Wout, 1, 1)
[out, ema_mean_bn2_upd, ema_var_bn2_upd, c_m, c_v] = bn2d::forward(out, gamma_bn2,
beta_bn2, C_base, Hout, Wout,
mode, ema_mean_bn2, ema_var_bn2,
mu_bn, epsilon_bn)

# IDENTITY PATH
identity = X
if (downsample) {
# downsample input
[identity, Hout, Wout] = conv1x1_forward(X, W_conv3, C_in, C_base,
Hin, Win, strideh, stridew)
[identity, ema_mean_bn3_upd, ema_var_bn3_upd, c_m, c_v] = bn2d::forward(identity,
gamma_bn3, beta_bn3, C_base, Hout, Wout,
mode, ema_mean_bn3, ema_var_bn3, mu_bn,
epsilon_bn)
}

out = relu::forward(out + identity)

ema_means_vars_upd = list(ema_mean_bn1_upd, ema_var_bn1_upd, ema_mean_bn2_upd, ema_var_bn2_upd)
if (downsample) {
ema_means_vars_upd = append(ema_means_vars_upd, ema_mean_bn3_upd)
ema_means_vars_upd = append(ema_means_vars_upd, ema_var_bn3_upd)
}
}

basic_reslayer_forward = function(matrix[double] X, int Hin, int Win, int blocks,
int strideh, int stridew, int C_in, int C_base,
list[unknown] blocks_weights, string mode,
list[unknown] ema_means_vars)
return (matrix[double] out, int Hout, int Wout,
list[unknown] ema_means_vars_upd) {
/*
* Executes the forward pass for a sequence of residual blocks
* with the same number of base channels, i.e. residual layer.
*
* Inputs:
* - X: Inputs, of shape (N, C_in*Hin*Win)
* - Hin: Input height.
* - Win: Input width.
* - blocks: Number of residual blocks (must be at least 1).
* - strideh: Stride height for first conv layer of first block.
* - stridew: Stride width for first conv layer of first block.
* - C_in: Number of input channels.
* - C_base: Number of base channels of res layer.
* - blocks_weights: List of weights of each block.
* -> i: List of weights of block i with the content
* defined in the docs of basic_block_forward().
* -> length == blocks
* - mode: 'train' or 'test' to indicate whether the model is currently
* being trained or tested, for the batch normalization layers.
* See batch_norm2d_old.dml docs for more info.
* - ema_means_vars: List of exponential moving averages for mean
* and variance for the batch normalization layers of each block.
* -> i: List of EMAs of block i with the content defined
* in the docs of basic_block_forward().
* -> length == blocks
*
* Outputs:
* - out: Output, of shape (N, C_base*Hout*Wout).
* - Hout: Output height.
* - Wout: Output width.
* - ema_means_vars_upd: List of updated EMAs for mean and variance
* of the batch normalization layers of each block.
*/
# default values
mu_bn = 0.1
epsilon_bn = 1e-05

# first block with provided stride
[out, Hout, Wout, emas1_upd] = basic_block_forward(X, as.list(blocks_weights[1]),
C_in, C_base, Hin, Win, strideh, stridew,
mode, as.list(ema_means_vars[1]))
ema_means_vars_upd = list(emas1_upd)

# remaining blocks with stride 1
for (i in 2:blocks) {
current_weights = as.list(blocks_weights[i])
current_emas = as.list(ema_means_vars[i])
[out, Hout, Wout, current_emas_upd] = basic_block_forward(X=out,
weights=current_weights, C_in=C_base, C_base=C_base,
Hin=Hout, Win=Wout, strideh=1, stridew=1, mode=mode,
ema_means_vars=current_emas)
ema_means_vars_upd = append(ema_means_vars_upd, current_emas_upd)
}
}

resnet18_forward = function(matrix[double] X, int Hin, int Win,
list[unknown] model, string mode,
list[unknown] ema_means_vars)
return (matrix[double] out, list[unknown] ema_means_vars_upd) {
/*
* Forward pass of the ResNet-18 model as introduced in
* "Deep Residual Learning for Image Recognition" by
* Kaiming He et al., inspired by the PyTorch
* implementation.
*
* Inputs:
* - X: Inputs, of shape (N, C_in*Hin*Win).
* C_in = 3 is expected.
* - Hin: Input height.
* - Win: Input width.
* - model: Weights and bias matrices of the model
* with the following order/content:
* -> 1: Weights of conv 1 7x7, of shape (64, 3*7*7)
* -> 2: Weights of batch norm 1, of shape (64, 1).
* -> 3: Bias of batch norm 1, of shape (64, 1).
* -> 4: List of weights for first residual layer
* with 64 base channels.
* -> 5: List of weights for second residual layer
* with 128 base channels.
* -> 6: List of weights for third residual layer
* with 256 base channels.
* -> 7: List of weights for fourth residual layer
* with 512 base channels.
* List of residual layers 1, 2, 3 & 4 have
* the content/order:
* -> 1: List of weights for first residual
* block.
* -> 2: List of weights for second residual
* block.
* Each list of weights for a residual block
* must follow the same order as defined in
* the documentation of basic_block_forward().
* -> 8: Weights of fully connected layer, of shape (512, 1000)
* -> 9: Bias of fully connected layer, of shape (1, 1000)
* - mode: 'train' or 'test' to indicate whether the model is currently
* being trained or tested, for the batch normalization layers.
* See batch_norm2d_old.dml docs for more info.
* - ema_means_vars: List of exponential moving averages for mean
* and variance for the batch normalization layers.
* -> 1: EMA for mean of batch norm 1, of shape (64, 1).
* -> 2: EMA for variance of batch norm 1, of shape (64, 1).
* -> 3: List of EMA means and vars for residual layer 1.
* -> 4: List of EMA means and vars for residual layer 2.
* -> 5: List of EMA means and vars for residual layer 3.
* -> 6: List of EMA means and vars for residual layer 4.
* Lists for EMAs of layer 1, 2, 3 & 4 must have the
* following order:
* -> 1: List of EMA means and vars for residual block 1.
* -> 2: List of EMA means and vars for residual block 2.
* Each list of EMAs for a residual block
* must follow the same order as defined in
* the documentation of basic_block_forward().
* - NOTICE: The lists of the first blocks for layer 2, 3 and 4
* must include weights and EMAs for 1 extra conv layer
* and a batch norm layer for the downsampling on the
* identity path.
*
* Outputs:
* - out: Outputs, of shape (N, 1000)
* - ema_means_vars_upd: List of updated exponential moving averages
* for mean and variance of the batch normalization layers. It follows
* the exact same structure as the input EMAs list.
*/
# default values
mu_bn = 0.1
epsilon_bn = 1e-05

# extract model params
W_conv1 = as.matrix(model[1])
gamma_bn1 = as.matrix(model[2]); beta_bn1 = as.matrix(model[3])
weights_reslayer1 = as.list(model[4])
weights_reslayer2 = as.list(model[5])
weights_reslayer3 = as.list(model[6])
weights_reslayer4 = as.list(model[7])
W_fc = as.matrix(model[8])
b_fc = as.matrix(model[9])
ema_mean_bn1 = as.matrix(ema_means_vars[1]); ema_var_bn1 = as.matrix(ema_means_vars[2])
emas_reslayer1 = as.list(ema_means_vars[3])
emas_reslayer2 = as.list(ema_means_vars[4])
emas_reslayer3 = as.list(ema_means_vars[5])
emas_reslayer4 = as.list(ema_means_vars[6])

# Convolutional 7x7 layer
C = 64
b_conv1 = matrix(0, rows=C, cols=1)
[out, Hout, Wout] = conv2d::forward(X=X, W=W_conv1, b=b_conv1, C=3,
Hin=Hin, Win=Win, Hf=7, Wf=7, strideh=2,
stridew=2, padh=3, padw=3)
# Batch Normalization
[out, ema_mean_bn1_upd, ema_var_bn1_upd, c_mean, c_var] = bn2d::forward(X=out,
gamma=gamma_bn1, beta=beta_bn1, C=C, Hin=Hout,
Win=Wout, mode=mode, ema_mean=ema_mean_bn1,
ema_var=ema_var_bn1, mu=mu_bn,
epsilon=epsilon_bn)
# ReLU
out = relu::forward(X=out)
# Max Pooling 3x3
[out, Hout, Wout] = mp2d::forward(X=out, C=C, Hin=Hout, Win=Wout, Hf=3,
Wf=3, strideh=2, stridew=2, padh=1, padw=1)

# residual layer 1
[out, Hout, Wout, emas1_upd] = basic_reslayer_forward(X=out, Hin=Hout,
Win=Wout, blocks=2, strideh=1, stridew=1, C_in=C,
C_base=64, blocks_weights=weights_reslayer1,
mode=mode, ema_means_vars=emas_reslayer1)
C = 64
# residual layer 2
[out, Hout, Wout, emas2_upd] = basic_reslayer_forward(X=out, Hin=Hout,
Win=Wout, blocks=2, strideh=2, stridew=2, C_in=C,
C_base=128, blocks_weights=weights_reslayer2,
mode=mode, ema_means_vars=emas_reslayer2)
C = 128
# residual layer 3
[out, Hout, Wout, emas3_upd] = basic_reslayer_forward(X=out, Hin=Hout,
Win=Wout, blocks=2, strideh=2, stridew=2, C_in=C,
C_base=256, blocks_weights=weights_reslayer3,
mode=mode, ema_means_vars=emas_reslayer3)
C = 256
# residual layer 4
[out, Hout, Wout, emas4_upd] = basic_reslayer_forward(X=out, Hin=Hout,
Win=Wout, blocks=2, strideh=2, stridew=2, C_in=C,
C_base=512, blocks_weights=weights_reslayer4,
mode=mode, ema_means_vars=emas_reslayer4)
C = 512

# Global Average Pooling
[out, Hout, Wout] = ap2d::forward(X=out, C=C, Hin=Hout, Win=Wout)
# Affine
out = fc::forward(X=out, W=W_fc, b=b_fc)

ema_means_vars_upd = list(ema_mean_bn1_upd, ema_var_bn1_upd,
emas1_upd, emas2_upd, emas3_upd, emas4_upd)
}
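A hypothetical caller of basic_block_forward might assemble the weight and EMA lists in the documented order like the sketch below. This driver is illustrative only, not part of the patch; the sizes and initializations (gamma 1, beta 0, EMA mean 0, EMA variance 1) are assumptions. Since C_in != C_base here, the block downsamples, so entries 7-9 and EMAs 5-6 are required:

```
N = 2; C_in = 3; C_base = 8; Hin = 16; Win = 16
X = rand(rows=N, cols=C_in*Hin*Win)

# weights in the order documented by basic_block_forward()
weights = list(
    rand(rows=C_base, cols=C_in*3*3),    # 1: conv 1
    matrix(1, rows=C_base, cols=1),      # 2: batch norm 1 gamma
    matrix(0, rows=C_base, cols=1),      # 3: batch norm 1 beta
    rand(rows=C_base, cols=C_base*3*3),  # 4: conv 2
    matrix(1, rows=C_base, cols=1),      # 5: batch norm 2 gamma
    matrix(0, rows=C_base, cols=1),      # 6: batch norm 2 beta
    rand(rows=C_base, cols=C_in*1*1),    # 7: downsample conv
    matrix(1, rows=C_base, cols=1),      # 8: downsample batch norm gamma
    matrix(0, rows=C_base, cols=1))      # 9: downsample batch norm beta

# EMAs start at mean 0, variance 1
emas = list(matrix(0, rows=C_base, cols=1), matrix(1, rows=C_base, cols=1),
            matrix(0, rows=C_base, cols=1), matrix(1, rows=C_base, cols=1),
            matrix(0, rows=C_base, cols=1), matrix(1, rows=C_base, cols=1))

[out, Hout, Wout, emas_upd] = basic_block_forward(X, weights, C_in, C_base,
                                                  Hin, Win, 1, 1, "train", emas)
```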
@@ -113,6 +113,11 @@ public void logcosh(){
run("logcosh.dml");
}

@Test
public void resnet() {
run("resnet.dml");
}

@Override
protected void run(String name) {
super.run("component/" + name);