From 025105e2f96978f1a4b69df9d20ab20d223a3a41 Mon Sep 17 00:00:00 2001 From: Ankan Banerjee Date: Thu, 3 Mar 2022 11:28:20 +0530 Subject: [PATCH] Optimized Res-block Fusion without SE (#1678) * misc changes to cudnn backend - replace all cudaMemcpyAsync used for loading weights with cudaMemcpy as source (in CPU memory) could be deleted before the async version of the function actually does the copy. - minor naming/style changes. - add comment explaining what the policy map layer does and how the layout conversion from CHW to HWC works. * fix typo in comment * clang-format * address review comment * Add 320 and 352 channel support for fused SE layer - just add template instantiations. - verified that it works and provides a (very) slight speedup. * Update fp16_kernels.cu * Simpler kernel for res-block fusion without SE - use constant block size of 64, splitting channel dimension also into multiple blocks as needed. - This allows arbitrarily large filter counts without running out of register file. * minor refactoring - allow using res block fusing opt for alternate layers (that don't have SE) even on GPUs that don't have enough shared memory. * minor functional fix * a few more fixes to get correct output hopefully functionally correct now. * fix cudnn backend build - missed the fact that it also uses Res block fusion :-/ * fix build errors * some more fixes * minor cleanup * remove --use_fast_math - as it doesn't improve performance. - some minor cleanup * fix indentation --- src/neural/cuda/common_kernels.cu | 21 ++- src/neural/cuda/cuda_common.h | 4 + src/neural/cuda/fp16_kernels.cu | 173 +++++++++++++------------ src/neural/cuda/layers.cc | 46 +++++-- src/neural/cuda/layers.h | 3 +- src/neural/cuda/network_cuda.cc | 14 +- src/neural/cuda/network_cudnn.cc | 8 +- src/neural/cuda/winograd_helper.inc | 190 +++++++++++++++++++++++++--- 8 files changed, 333 insertions(+), 126 deletions(-) diff --git a/src/neural/cuda/common_kernels.cu b/src/neural/cuda/common_kernels.cu index 78f0e51df4..06f4ae8e4f 100644 --- a/src/neural/cuda/common_kernels.cu +++ b/src/neural/cuda/common_kernels.cu @@ -556,16 +556,22 @@ void OutputInputTransform(int N, int C, int se_K, T* output, const T* input, const T* b1, const T* w2, const T* b2, cudaStream_t stream) { // Each thread processes entire chess board - if (C > kMaxResBlockFusingChannels) { + if (use_se == false) { + dim3 grid_dim(DivUp(C, kOpInpTransformBlockSize), N, 1); + OutputTransform_relu_InputTransform_kernel + <<>>(N, C, output, input, + (float*)skip, bias); + } else if (C > kMaxResBlockFusingChannels) { throw Exception( "res block fusing opt not supported for the given data type and no " "of filters\n"); } else { - OutputTransform_SE_relu_InputTransform_kernel <<>>(N, C, se_K, output, input, (float*)skip, bias, w1, b1, w2, b2); } + ReportCUDAErrors(cudaGetLastError()); } @@ -843,6 +849,7 @@ template void OutputTransform( const float* w2, const float* b2, cudaStream_t stream); template void OutputTransform( + int N, int C, int se_K, float* output, const float* input, const float* skip, const float* bias, const float* w1, const float* b1, const float* w2, const float* b2, cudaStream_t stream); @@ -867,6 +874,11 @@ template void OutputTransform( const float* skip, const float* bias, const float* w1, const float* b1, const float* w2, const float* b2, cudaStream_t stream); +template void OutputTransform( + int N, int C, int se_K, float* output, const float* input, + const float* skip, const float* bias, const float* w1, const float* b1, + const float* 
w2, const float* b2, cudaStream_t stream); + template void OutputTransform( int N, int C, int se_K, float* output, const float* input, const float* skip, const float* bias, const float* w1, const float* b1, @@ -897,6 +909,11 @@ template void OutputTransform( const float* skip, const float* bias, const float* w1, const float* b1, const float* w2, const float* b2, cudaStream_t stream); +template void OutputTransform( + int N, int C, int se_K, float* output, const float* input, + const float* skip, const float* bias, const float* w1, const float* b1, + const float* w2, const float* b2, cudaStream_t stream); + template void OutputTransform( int N, int C, int se_K, float* output, const float* input, const float* skip, const float* bias, const float* w1, const float* b1, diff --git a/src/neural/cuda/cuda_common.h b/src/neural/cuda/cuda_common.h index f204423eb9..759238cd4e 100644 --- a/src/neural/cuda/cuda_common.h +++ b/src/neural/cuda/cuda_common.h @@ -55,6 +55,10 @@ static constexpr int kMaxResBlockFusingSeKFp16Ampere = 512; // (use a different kernel with reduced register pressure) static constexpr int kMaxResBlockFusingSeK = 128; // limit on (num_filters / se_ratio) +static constexpr int kMaxResBlockFusingSeFp16AmpereSmem = + 72 * kMaxResBlockFusingSeKFp16Ampere * + sizeof(half); // shared memory used by the special + // kernel #ifdef USE_CUDNN void CudnnError(cudnnStatus_t status, const char* file, const int& line); diff --git a/src/neural/cuda/fp16_kernels.cu b/src/neural/cuda/fp16_kernels.cu index aebe702082..5f59a68973 100644 --- a/src/neural/cuda/fp16_kernels.cu +++ b/src/neural/cuda/fp16_kernels.cu @@ -207,7 +207,7 @@ bool Se_Fp16_NHWC(int N, int C, int numFc1Out, half* output, const half* skip, // 'C' threads per block // 'N' blocks // Every thread generates an entire board/plane (8x8 elements). 
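// Illustrative sketch (not from the patch) of the launch geometry the comment
// above describes: one block per batch entry n, one thread per channel k, and
// each thread owning a full 8x8 board in registers. Kernel name and the plain
// NCHW indexing are hypothetical; the real kernels work on Winograd-transformed
// tensors with the INDEX_NHCW / TEMP_INDEX_HWNC layouts defined in this file.
#include <cuda_fp16.h>

__global__ void BoardPerThreadSketch(int N, int C, const half* in, half* out,
                                     const half* bias) {
  int k = threadIdx.x;  // channel index; blockDim.x == C
  int n = blockIdx.x;   // batch index;  gridDim.x  == N
  half board[8][8];     // the whole 8x8 plane lives in registers
  for (int y = 0; y < 8; y++)
    for (int x = 0; x < 8; x++)
      board[y][x] = in[(((size_t)n * C + k) * 8 + y) * 8 + x];
  // ... per-board work (bias add, SE, activation, skip add) goes here ...
  for (int y = 0; y < 8; y++)
    for (int x = 0; x < 8; x++)
      out[(((size_t)n * C + k) * 8 + y) * 8 + x] =
          (half)((float)board[y][x] + (float)bias[k]);
}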
-template __global__ __launch_bounds__(kMaxResBlockFusingSeKFp16Ampere,1) void OutputInputTransformKernel_fp16_shmem_board( @@ -248,105 +248,97 @@ void OutputInputTransformKernel_fp16_shmem_board( float S = 0; float B = 0; - if (use_bias || use_se) { #pragma unroll - for (int y = 0; y < 8; y++) { - half boardRow[8]; - copyAs(&boardRow, &BOARD(y, 0)); + for (int y = 0; y < 8; y++) { + half boardRow[8]; + copyAs(&boardRow, &BOARD(y, 0)); #pragma unroll - for (int x = 0; x < 8; x++) { - if (use_bias) boardRow[x] += b; - if (use_se) S += (float)boardRow[x]; - } - if (use_bias) copyAs(&BOARD(y, 0), &boardRow); + for (int x = 0; x < 8; x++) { + if (use_bias) boardRow[x] += b; + S += (float)boardRow[x]; } + if (use_bias) copyAs(&BOARD(y, 0), &boardRow); } - if (use_se) { - __shared__ float shared_data[kMaxResBlockFusingSeKFp16Ampere]; - float avg = S / 64; - shared_data[k] = avg; + __shared__ float shared_data[kMaxResBlockFusingSeKFp16Ampere]; + float avg = S / 64; + shared_data[k] = avg; - int lane = k & 0x1F; - int warp = k >> 5; - __syncthreads(); + int lane = k & 0x1F; + int warp = k >> 5; + __syncthreads(); - // First fully-connected layer for SE + // First fully-connected layer for SE - // As se_K << C, we want to loop over se_K instead of C - // even if it means taking the sum across threads + // As se_K << C, we want to loop over se_K instead of C + // even if it means taking the sum across threads - __shared__ float shared_sums[kMaxResBlockFusingSeKFp16Ampere / 32] - [kMaxResBlockFusingSeK]; // per-warp sums + __shared__ float shared_sums[kMaxResBlockFusingSeKFp16Ampere / 32] + [kMaxResBlockFusingSeK]; // per-warp sums - for (int i = 0; i < se_K; i++) { - float val = shared_data[k] * float(readw1(k, i)); - val = warpReduce(val); - if (lane == 0) shared_sums[warp][i] = val; - } - __syncthreads(); - if (k < se_K) { - S = 0; - for (int i = 0; i < C / 32; i++) S += shared_sums[i][k]; - - S += (float)b1[k]; - S = activate(S, activation); - shared_data[k] = S; - } + for (int i = 0; i < se_K; i++) { + float val = shared_data[k] * float(readw1(k, i)); + val = warpReduce(val); + if (lane == 0) shared_sums[warp][i] = val; + } + __syncthreads(); + if (k < se_K) { + S = 0; + for (int i = 0; i < C / 32; i++) S += shared_sums[i][k]; - __syncthreads(); + S += (float)b1[k]; + S = activate(S, activation); + shared_data[k] = S; + } - // Second fully-connected layer for SE - S = 0; - for (int i = 0; i < se_K; i++) { - float val = shared_data[i]; - S += val * float(readw2(i, k)); - B += val * float(readw2(i, k + C)); - } - S += (float)b2[k]; - B += (float)b2[k + C]; + __syncthreads(); - // Sigmoid (only on the scale part). - S = 1.0f / (1.0f + exp(-S)); + // Second fully-connected layer for SE + S = 0; + for (int i = 0; i < se_K; i++) { + float val = shared_data[i]; + S += val * float(readw2(i, k)); + B += val * float(readw2(i, k + C)); } + S += (float)b2[k]; + B += (float)b2[k + C]; - // Scale/bias, add skip connection, perform relu, and write to output. - if (use_se || use_skip || activation != NONE) { - for (int h = 0; h < 8; h++) { - half boardRow[8]; - copyAs(&boardRow[0], &BOARD(h, 0)); + // Sigmoid (only on the scale part). + S = 1.0f / (1.0f + exp(-S)); + + // Scale/bias, add skip connection, perform activation, and write to output. 
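// Illustrative CPU reference (not from the patch) of the SE computation the
// kernel above performs per image. Assumptions: w1 stored row-major as
// C x se_K and w2 as se_K x 2C, matching how readw1/readw2 are indexed above;
// function and variable names are hypothetical. pooled[c] is the per-channel
// global average of the 8x8 board (taken after the bias add).
#include <cmath>
#include <vector>

void SeScaleBiasReference(int C, int se_K, const std::vector<float>& pooled,
                          const std::vector<float>& w1,
                          const std::vector<float>& b1,
                          const std::vector<float>& w2,
                          const std::vector<float>& b2,
                          std::vector<float>* S,   // C scales (caller-sized)
                          std::vector<float>* B) { // C biases (caller-sized)
  // First fully-connected layer (C -> se_K), followed by the block's
  // activation (ReLU shown here for simplicity).
  std::vector<float> fc1(se_K);
  for (int i = 0; i < se_K; i++) {
    float v = b1[i];
    for (int c = 0; c < C; c++) v += pooled[c] * w1[c * se_K + i];
    fc1[i] = v > 0.0f ? v : 0.0f;
  }
  // Second fully-connected layer (se_K -> 2C): the first C outputs are scales,
  // the last C outputs are biases; the sigmoid is applied to the scale only.
  for (int k = 0; k < C; k++) {
    float s = b2[k], b = b2[k + C];
    for (int i = 0; i < se_K; i++) {
      s += fc1[i] * w2[i * 2 * C + k];
      b += fc1[i] * w2[i * 2 * C + k + C];
    }
    (*S)[k] = 1.0f / (1.0f + std::exp(-s));
    (*B)[k] = b;
  }
  // Each output element then becomes: activate(value * S[k] + B[k] + skip).
}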
+ for (int h = 0; h < 8; h++) { + half boardRow[8]; + copyAs(&boardRow[0], &BOARD(h, 0)); - if (use_se) { -#pragma unroll - for (int w = 0; w < 8; w++) { - boardRow[w] = (half)(float(boardRow[w]) * S + B); - } - } - - // residual add - if (use_skip) { - half skipInp[8]; - copyAs(&skipInp[0], &skip[INDEX_NHCW(n, k, h, 0)]); #pragma unroll - for (int w = 0; w < 8; w++) boardRow[w] += skipInp[w]; - } + for (int w = 0; w < 8; w++) { + boardRow[w] = (half)(float(boardRow[w]) * S + B); + } - // relu - if (activation != NONE) { + // residual add + if (use_skip) { + half skipInp[8]; + copyAs(&skipInp[0], &skip[INDEX_NHCW(n, k, h, 0)]); #pragma unroll - for (int w = 0; w < 8; w++) - boardRow[w] = (half)activate((float)boardRow[w], activation); - } + for (int w = 0; w < 8; w++) boardRow[w] += skipInp[w]; + } - // write un-transformed output to 'skip' if required - if (use_skip) { - copyAs(&skip[INDEX_NHCW(n, k, h, 0)], &boardRow[0]); - } + if (activation != NONE) { +#pragma unroll + for (int w = 0; w < 8; w++) + boardRow[w] = (half)activate((float)boardRow[w], activation); + } - copyAs(&BOARD(h, 0), &boardRow); + // write un-transformed output to 'skip' if required + if (use_skip) { + copyAs(&skip[INDEX_NHCW(n, k, h, 0)], &boardRow[0]); } + + copyAs(&BOARD(h, 0), &boardRow); } + // Perform input transform. int c = k; @@ -434,17 +426,24 @@ void OutputInputTransform(int N, int C, int se_K, T* output, const T* input, const T* b1, const T* w2, const T* b2, cudaStream_t stream) { // Each thread processes entire chess board. - if (C > kMaxResBlockFusingChannels) { + if (use_se == false) { + dim3 grid_dim(DivUp(C, kOpInpTransformBlockSize), N, 1); + OutputTransform_relu_InputTransform_kernel + <<>>(N, C, output, input, + (half*)skip, bias); + } else if (C > kMaxResBlockFusingChannels) { // Use special kernel with reduced register pressure - only works on Ampere, // and only for fp16. 
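// Illustrative sketch (not from the patch) of the dynamic shared memory opt-in
// used just below: a kernel needing more than the default 48 KB per block must
// request it via cudaFuncSetAttribute before launch, and the device must offer
// that much (deviceProp.sharedMemPerBlockOptin). Kernel and constant names here
// are hypothetical; the cudaFuncSetAttribute usage mirrors the patch's own.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void LargeSmemSketchKernel(float* out) {
  extern __shared__ float smem[];  // sized at launch time
  smem[threadIdx.x] = (float)threadIdx.x;
  __syncthreads();
  if (threadIdx.x == 0) out[blockIdx.x] = smem[blockDim.x - 1];
}

int main() {
  const int kSmemBytes = 72 * 1024;  // 72 KB, above the 48 KB default
  cudaFuncSetAttribute(LargeSmemSketchKernel,
                       cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemBytes);
  float* out = nullptr;
  cudaMalloc(&out, sizeof(float));
  LargeSmemSketchKernel<<<1, 256, kSmemBytes>>>(out);
  // On GPUs whose opt-in shared memory per block is below 72 KB this launch
  // fails, which is why the patch checks the device property before taking
  // this code path.
  printf("launch: %s\n", cudaGetErrorString(cudaGetLastError()));
  cudaFree(out);
  return 0;
}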
if (C <= kMaxResBlockFusingSeKFp16Ampere) { cudaFuncSetAttribute( - OutputInputTransformKernel_fp16_shmem_board, - cudaFuncAttributeMaxDynamicSharedMemorySize, 72 * 1024); - OutputInputTransformKernel_fp16_shmem_board - <<>>( + <<>>( N, C, se_K, (half*)output, (const half*)input, (half*)skip, (half*)bias, (half*)w1, (half*)b1, (half*)w2, (half*)b2); } else { @@ -453,7 +452,7 @@ void OutputInputTransform(int N, int C, int se_K, T* output, const T* input, "of filters\n"); } } else { - OutputTransform_SE_relu_InputTransform_kernel <<>>(N, C, se_K, output, input, (half*)skip, bias, w1, b1, w2, b2); @@ -501,6 +500,11 @@ template void OutputTransform( const half* bias, const half* w1, const half* b1, const half* w2, const half* b2, cudaStream_t stream); +template void OutputTransform( + int N, int C, int se_K, half* output, const half* input, const half* skip, + const half* bias, const half* w1, const half* b1, const half* w2, + const half* b2, cudaStream_t stream); + template void OutputTransform( int N, int C, int se_K, half* output, const half* input, const half* skip, const half* bias, const half* w1, const half* b1, const half* w2, @@ -531,6 +535,11 @@ template void OutputTransform( const half* bias, const half* w1, const half* b1, const half* w2, const half* b2, cudaStream_t stream); +template void OutputTransform( + int N, int C, int se_K, half* output, const half* input, const half* skip, + const half* bias, const half* w1, const half* b1, const half* w2, + const half* b2, cudaStream_t stream); + template void OutputTransform( int N, int C, int se_K, half* output, const half* input, const half* skip, const half* bias, const half* w1, const half* b1, const half* w2, diff --git a/src/neural/cuda/layers.cc b/src/neural/cuda/layers.cc index 1b1a139f41..58ccc35c03 100644 --- a/src/neural/cuda/layers.cc +++ b/src/neural/cuda/layers.cc @@ -1055,13 +1055,15 @@ Conv1Layer::~Conv1Layer() { template ResidualBlock::ResidualBlock(BaseLayer* ip, int C, bool se, int se_k, bool use_gemm_ex, bool first, - bool last, ActivationFunction activation) + + bool last, ActivationFunction activation, int shared_mem_size) : BaseLayer(C, 8, 8, ip, ip->isNHWC(), use_gemm_ex), has_se_(se), se_k_(se_k), c_input_(C), first_block_(first), last_block_(last), + shared_mem_size_(shared_mem_size), act_(activation) { if (act_ != RELU && act_ != MISH) { throw Exception("Unsupported activation for residual block."); @@ -1229,6 +1231,12 @@ void ResidualBlock::Eval(int N, DataType* output, transformed_input, transformed_weights1_, transformed_output, N * 4, C, C, 36, cublas); + const bool fp16 = std::is_same::value; + bool allowFusing = + (C <= kMaxResBlockFusingChannels) || + (fp16 && (shared_mem_size_ >= kMaxResBlockFusingSeFp16AmpereSmem) && + (C <= kMaxResBlockFusingSeKFp16Ampere)); + if (act_ == RELU) { if (last_block_) { if (has_se_) @@ -1240,11 +1248,19 @@ void ResidualBlock::Eval(int N, DataType* output, N, C, se_k_, output, transformed_output, input, biases1_, w1_, b1_, w2_, b2_, stream); } else { - if (has_se_) - OutputInputTransform( - N, C, se_k_, output, transformed_output, input, biases1_, w1_, b1_, - w2_, b2_, stream); - else + if (has_se_) { + if (allowFusing) { + OutputInputTransform( + N, C, se_k_, output, transformed_output, input, biases1_, w1_, + b1_, w2_, b2_, stream); + } else { + OutputTransform( + N, C, se_k_, (DataType*)input, transformed_output, input, + biases1_, w1_, b1_, w2_, b2_, stream); + InputTransform(N, C, output, (DataType*)input, + stream); + } + } else OutputInputTransform( N, C, se_k_, 
output, transformed_output, input, biases1_, w1_, b1_, w2_, b2_, stream); @@ -1260,11 +1276,19 @@ void ResidualBlock::Eval(int N, DataType* output, N, C, se_k_, output, transformed_output, input, biases1_, w1_, b1_, w2_, b2_, stream); } else { - if (has_se_) - OutputInputTransform( - N, C, se_k_, output, transformed_output, input, biases1_, w1_, b1_, - w2_, b2_, stream); - else + if (has_se_) { + if (allowFusing) { + OutputInputTransform( + N, C, se_k_, output, transformed_output, input, biases1_, w1_, + b1_, w2_, b2_, stream); + } else { + OutputTransform( + N, C, se_k_, (DataType*)input, transformed_output, input, + biases1_, w1_, b1_, w2_, b2_, stream); + InputTransform(N, C, output, (DataType*)input, + stream); + } + } else OutputInputTransform( N, C, se_k_, output, transformed_output, input, biases1_, w1_, b1_, w2_, b2_, stream); diff --git a/src/neural/cuda/layers.h b/src/neural/cuda/layers.h index 6b8409d24b..c1059b6326 100644 --- a/src/neural/cuda/layers.h +++ b/src/neural/cuda/layers.h @@ -299,7 +299,7 @@ class ResidualBlock : public BaseLayer { public: ResidualBlock(BaseLayer* ip, int C, bool se, int se_k, bool use_gemm_ex, bool first, bool last, - ActivationFunction activation); + ActivationFunction activation, int shared_mem_size); ~ResidualBlock(); void LoadWeights0(float* pfilter, float* pBias, void* scratch); @@ -317,6 +317,7 @@ class ResidualBlock : public BaseLayer { const int c_input_; const bool first_block_; const bool last_block_; + const int shared_mem_size_; const ActivationFunction act_; DataType* biases0_ = nullptr; diff --git a/src/neural/cuda/network_cuda.cc b/src/neural/cuda/network_cuda.cc index 9090f4c33b..d7c1122bac 100644 --- a/src/neural/cuda/network_cuda.cc +++ b/src/neural/cuda/network_cuda.cc @@ -237,14 +237,12 @@ class CudaNetwork : public Network { "using a smaller network."; } - // Disable res block fusing for > 512 filters (the fused output input - // transform kernel runs out of register space) and for fp32 for now. + // Disable res block fusing for fp32 for now (not worth it) // TODO: make it work for filters not a multiple of 32. - if ((kNumFilters <= kMaxResBlockFusingChannels || - ((deviceProp.major >= 8 || - (deviceProp.major == 7 && deviceProp.minor != 5)) && - kNumFilters <= kMaxResBlockFusingSeKFp16Ampere)) && - kNumFilters % 32 == 0 && fp16) { + // Note that when used with SE, the optimization + // works only when filter count is <= 384 (pre-Ampere), or less than 512 (Ampere) + // It turns dynamically off based on filter count (see ResidualBlock::Eval) + if (kNumFilters % 32 == 0 && std::is_same::value) { use_res_block_winograd_fuse_opt_ = true; } else { use_res_block_winograd_fuse_opt_ = false; @@ -312,7 +310,7 @@ class CudaNetwork : public Network { if (use_res_block_winograd_fuse_opt_) { auto layer = std::make_unique>( getLastLayer(), kNumFilters, has_se, se_k, use_gemm_ex, block == 0, - block == (numBlocks_ - 1), mish_net ? MISH : RELU); + block == (numBlocks_ - 1), mish_net ? 
MISH : RELU, deviceProp.sharedMemPerBlockOptin); layer->LoadWeights0(&weights.residual[block].conv1.weights[0], &weights.residual[block].conv1.biases[0], scratch_mem_); diff --git a/src/neural/cuda/network_cudnn.cc b/src/neural/cuda/network_cudnn.cc index 476eb4f125..4a0e3016ca 100644 --- a/src/neural/cuda/network_cudnn.cc +++ b/src/neural/cuda/network_cudnn.cc @@ -296,10 +296,9 @@ class CudnnNetwork : public Network { use_res_block_winograd_fuse_opt_ = false; if (use_custom_winograd_) { - // Disable res block fusing for > 384 filters (the fused output input - // transform kernel runs out of register space) and for fp32 for now. + // Disable res block fusing for fp32 for now. // TODO: make it work for filters not a multiple of 32. - if (kNumFilters <= 384 && kNumFilters % 32 == 0 && fp16) { + if (kNumFilters % 32 == 0 && fp16) { use_res_block_winograd_fuse_opt_ = true; } // Override if set in backend-opts. @@ -413,7 +412,8 @@ class CudnnNetwork : public Network { if (use_res_block_winograd_fuse_opt_) { auto layer = std::make_unique>( getLastLayer(), kNumFilters, has_se, se_k, use_gemm_ex, - block == 0, block == (numBlocks_ - 1), mish_net ? MISH : RELU); + block == 0, block == (numBlocks_ - 1), mish_net ? MISH : RELU, + deviceProp.sharedMemPerBlockOptin); layer->LoadWeights0(&weights.residual[block].conv1.weights[0], &weights.residual[block].conv1.biases[0], scratch_mem_); diff --git a/src/neural/cuda/winograd_helper.inc b/src/neural/cuda/winograd_helper.inc index 48189ca9ca..456649ba87 100644 --- a/src/neural/cuda/winograd_helper.inc +++ b/src/neural/cuda/winograd_helper.inc @@ -30,11 +30,12 @@ namespace cudnn_backend { __device__ __forceinline__ float mishActivate(float el) { auto e = __expf(el); - auto n = e * e + 2 * e; + auto n = e * e + 2.0f * e; + auto d = __fdividef(el, n + 2.0f); if (el <= -0.6f) { - return n * __fdividef(el, n + 2); + return n * d; } else { - return el - 2 * __fdividef(el, n + 2); + return el - 2.0f * d; } } __device__ __forceinline__ float activate(float cVal, @@ -47,14 +48,14 @@ __device__ __forceinline__ float activate(float cVal, cVal = tanh(cVal); break; case SIGMOID: - cVal = 1.0f / (1.0f + exp(-cVal)); + cVal = 1.0f / (1.0f + __expf(-cVal)); break; case SELU: { float alpha = 1.67326324f, scale = 1.05070098f; if (cVal > 0) cVal = scale * cVal; else - cVal = scale * alpha * (exp(cVal) - 1); + cVal = scale * alpha * (__expf(cVal) - 1.0f); break; } case MISH: @@ -430,7 +431,7 @@ __device__ __forceinline__ void copyAs(void* dst, const void* src) { // 'C' threads per block // 'N' blocks // every thread generates an entire board/plane (8x8 elements) -template __global__ __launch_bounds__(kMaxResBlockFusingChannels, 1) void OutputTransform_SE_relu_InputTransform_kernel( @@ -444,13 +445,6 @@ void OutputTransform_SE_relu_InputTransform_kernel( T board[8][8]; T b = bias[k]; - T skipInp[8][8]; -#pragma unroll - for (int h = 0; h < 8; h++) { - copyAs(&skipInp[h][0], &skip[INDEX_NHCW(n, k, h, 0)]); - if (!fp16) copyAs(&skipInp[h][4], &skip[INDEX_NHCW(n, k, h, 4)]); - } - #pragma unroll for (int hStart = 0; hStart < 8; hStart += 4) #pragma unroll @@ -483,10 +477,10 @@ void OutputTransform_SE_relu_InputTransform_kernel( #pragma unroll for (int x = 0; x < 8; x++) { if (use_bias) board[y][x] += b; - if (use_se) S += (float)board[y][x]; + S += (float)board[y][x]; } - if (use_se) { + { __shared__ float shared_data[kMaxResBlockFusingChannels]; float avg = S / 64; shared_data[k] = avg; @@ -536,14 +530,16 @@ void OutputTransform_SE_relu_InputTransform_kernel( // Scale/bias, add 
skip connection, perform relu, and write to output. for (int h = 0; h < 8; h++) { - if (use_se) #pragma unroll - for (int w = 0; w < 8; w++) board[h][w] = (T)(float(board[h][w]) * S + B); + for (int w = 0; w < 8; w++) board[h][w] = (T)(float(board[h][w]) * S + B); // residual add if (use_skip) { + T skipInp[8]; + copyAs(&skipInp[0], &skip[INDEX_NHCW(n, k, h, 0)]); + if (!fp16) copyAs(&skipInp[4], &skip[INDEX_NHCW(n, k, h, 4)]); #pragma unroll - for (int w = 0; w < 8; w++) board[h][w] += skipInp[h][w]; + for (int w = 0; w < 8; w++) board[h][w] += skipInp[w]; } // relu @@ -641,6 +637,164 @@ void OutputTransform_SE_relu_InputTransform_kernel( } } + +constexpr int kOpInpTransformBlockSize = 64; +template +__global__ __launch_bounds__(kOpInpTransformBlockSize, 4) +void OutputTransform_relu_InputTransform_kernel(int N, int C, + T* output, const T* input, + T* skip, const T* bias) { + const bool fp16 = std::is_same::value; + + int k = threadIdx.x + blockIdx.x * kOpInpTransformBlockSize; + if (k >= C) return; // wasted threads (for non-multiple of 64 channel counts) + int n = blockIdx.y; + + T board[8][8]; + T b = bias[k]; + + T skipInp[8][8]; +#pragma unroll + for (int h = 0; h < 8; h++) { + copyAs(&skipInp[h][0], &skip[INDEX_NHCW(n, k, h, 0)]); + if (!fp16) copyAs(&skipInp[h][4], &skip[INDEX_NHCW(n, k, h, 4)]); + } + +#pragma unroll + for (int hStart = 0; hStart < 8; hStart += 4) +#pragma unroll + for (int wStart = 0; wStart < 8; wStart += 4) { + // i) read to per thread registers (for doing output transform) + int shln = n * 4 + (hStart / 4) * 2 + (wStart / 4); + T outElTransformed[6][6]; +#pragma unroll + for (int y = 0; y < 6; y++) +#pragma unroll + for (int x = 0; x < 6; x++) + outElTransformed[y][x] = input[TEMP_INDEX_HWNC(y, x, shln, k)]; + + // ii) transform it + T outEl[4][4]; + OutputTransform4x4(&outEl[0][0], &outElTransformed[0][0]); + +#pragma unroll + for (int y = 0; y < 4; y++) +#pragma unroll + for (int x = 0; x < 4; x++) board[hStart + y][wStart + x] = outEl[y][x]; + } + + // Add bias +#pragma unroll + for (int y = 0; y < 8; y++) +#pragma unroll + for (int x = 0; x < 8; x++) + if (use_bias) board[y][x] += b; + + + // Add skip connection, perform relu, and write to output. 
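// Illustrative sketch (not from the patch) of the 128-bit copy idiom behind
// the copyAs<> calls in this kernel. The template argument was lost in the
// listing above; a 16-byte vector type such as uint4 fits the access pattern:
// one board row of 8 half values is 16 bytes and moves in a single copy, while
// 8 float values need two copies - hence the extra "if (!fp16)" copy of
// elements 4..7. Both pointers are assumed 16-byte aligned.
#include <cuda_fp16.h>

template <typename T>
__device__ __forceinline__ void CopyAsSketch(void* dst, const void* src) {
  *reinterpret_cast<T*>(dst) = *reinterpret_cast<const T*>(src);
}

__device__ void LoadBoardRow(half (&row)[8], const half* src) {
  CopyAsSketch<uint4>(&row[0], src);      // 8 halves = 16 bytes, one transaction
}

__device__ void LoadBoardRow(float (&row)[8], const float* src) {
  CopyAsSketch<uint4>(&row[0], src);      // first 4 floats
  CopyAsSketch<uint4>(&row[4], src + 4);  // last 4 floats
}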
+ for (int h = 0; h < 8; h++) { + // residual add + if (use_skip) { +#pragma unroll + for (int w = 0; w < 8; w++) board[h][w] += skipInp[h][w]; + } + + // activation + if (activation != NONE) { +#pragma unroll + for (int w = 0; w < 8; w++) + board[h][w] = (T) activate((float)board[h][w], activation); + } + + // write un-transformed output to 'skip' if required + if (use_skip) { + // Write to skip (use 128 bit writes to store one row a time) + copyAs(&skip[INDEX_NHCW(n, k, h, 0)], &board[h][0]); + if (!fp16) copyAs(&skip[INDEX_NHCW(n, k, h, 4)], &board[h][4]); + } + } + + // perform input transform + + int c = k; + // top-left + { + T inEl[6][6] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + +#pragma unroll + for (int i = 0; i < 5; i++) +#pragma unroll + for (int j = 0; j < 5; j++) inEl[i + 1][j + 1] = board[i][j]; + + InputTransform4x4(&inEl[0][0], &inEl[0][0]); + +#pragma unroll + for (int y = 0; y < 6; y++) +#pragma unroll + for (int x = 0; x < 6; x++) + output[TEMP_INDEX_HWNC(y, x, n * 4 + 0, c)] = inEl[y][x]; + } + + // top-right + { + T inEl[6][6] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + +#pragma unroll + for (int i = 0; i < 5; i++) +#pragma unroll + for (int j = 0; j < 5; j++) inEl[i + 1][j] = board[i][j + 3]; + + InputTransform4x4(&inEl[0][0], &inEl[0][0]); + +#pragma unroll + for (int y = 0; y < 6; y++) +#pragma unroll + for (int x = 0; x < 6; x++) + output[TEMP_INDEX_HWNC(y, x, n * 4 + 1, c)] = inEl[y][x]; + } + + // bottom-left + { + T inEl[6][6] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + +#pragma unroll + for (int i = 0; i < 5; i++) +#pragma unroll + for (int j = 0; j < 5; j++) inEl[i][j + 1] = board[i + 3][j]; + + InputTransform4x4(&inEl[0][0], &inEl[0][0]); + +#pragma unroll + for (int y = 0; y < 6; y++) +#pragma unroll + for (int x = 0; x < 6; x++) + output[TEMP_INDEX_HWNC(y, x, n * 4 + 2, c)] = inEl[y][x]; + } + + // bottom-right + { + T inEl[6][6] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + +#pragma unroll + for (int i = 0; i < 5; i++) +#pragma unroll + for (int j = 0; j < 5; j++) inEl[i][j] = board[i + 3][j + 3]; + + InputTransform4x4(&inEl[0][0], &inEl[0][0]); + +#pragma unroll + for (int y = 0; y < 6; y++) +#pragma unroll + for (int x = 0; x < 6; x++) + output[TEMP_INDEX_HWNC(y, x, n * 4 + 3, c)] = inEl[y][x]; + } +} + + template void FilterTransform(int N, int C, T* transformedFilter, const T* filter) { // Each thread processes entire filter block (input 3x3 elements -> output 6x6