From cfb86c4e23d424328066fe8d2fbbacb9c9ead6c1 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Sun, 13 Aug 2017 09:30:41 +0800 Subject: [PATCH 01/20] Add vol2col and col2vol cuda kernel --- paddle/cuda/include/hl_matrix.h | 58 ++++++++++ paddle/cuda/include/stub/hl_matrix_stub.h | 15 +++ paddle/cuda/src/hl_cuda_matrix.cu | 135 ++++++++++++++++++++++ 3 files changed, 208 insertions(+) diff --git a/paddle/cuda/include/hl_matrix.h b/paddle/cuda/include/hl_matrix.h index eb454c59c1e58..da2ed8cabb766 100644 --- a/paddle/cuda/include/hl_matrix.h +++ b/paddle/cuda/include/hl_matrix.h @@ -224,4 +224,62 @@ extern void hl_matrix_collect_shared_bias(real* B_d, extern void hl_matrix_rotate( real* mat, real* matRot, int dimM, int dimN, bool clockWise); +/** + * @brief Matrix vol2Col: Convert 3D volume into col matrix + * + * @param[in] matSrc input matrix. + * @param[in] channel channel of matSrc. + * @param[in] depth depth of matSrc. + * @param[in] height height of matSrc. + * @param[in] width width of matSrc. + * @param[in] filterD depth of filter. + * @param[in] filterH height of filter. + * @param[in] filterW width of filter. + * @param[in] strideD stride in the depth. + * @param[in] strideH stride in the height. + * @param[in] strideW stride in the width. + * @param[in] paddingD padding in the depth. + * @param[in] paddingH padding in the height. + * @param[in] paddingW padding in the width. + * @param[out] matDst output matrix. + * + */ +extern void hl_matrix_vol2Col(real* matSrc, + int channel, int depth, int height, int width, + int filterD, int filterH, int filterW, + int strideD, int strideH, int strideW, + int paddingD, int paddingH, int paddingW, + real* matDst); + +/** + * @brief Matrix col2Vol: Convert col matrix into 3D volume + * + * @param[out] matDst output matrix. + * @param[in] channel channel of matDst. + * @param[in] depth depth of matDst. + * @param[in] height height of matDst. + * @param[in] width width of matDst. + * @param[in] filterD depth of filter. + * @param[in] filterH height of filter. + * @param[in] filterW width of filter. + * @param[in] strideD stride in the depth. + * @param[in] strideH stride in the height. + * @param[in] strideW stride in the width. + * @param[in] paddingD padding in the depth. + * @param[in] paddingH padding in the height. + * @param[in] paddingW padding in the width. + * @param[in] matSrc input matrix. + * @param[in] beta input + * @param[in] alpha input + * + */ +extern void hl_matrix_col2Vol(real* matDst, + int channels, int depth, int height, int width, + int filterD, int filterH, int filterW, + int strideD, int strideH, int strideW, + int paddingD, int paddingH, int paddingW, + real* matSrc, + real alpha, real beta); + + #endif /* HL_MATRIX_H_ */ diff --git a/paddle/cuda/include/stub/hl_matrix_stub.h b/paddle/cuda/include/stub/hl_matrix_stub.h index 127cb7e27983e..0b73777812ae6 100644 --- a/paddle/cuda/include/stub/hl_matrix_stub.h +++ b/paddle/cuda/include/stub/hl_matrix_stub.h @@ -99,4 +99,19 @@ inline void hl_matrix_collect_shared_bias(real* B_d, inline void hl_matrix_rotate( real* mat, real* matRot, int dimM, int dimN, bool clockWise) {} +inline void hl_matrix_vol2Col(real* data, + int channels, int depth, int height, int width, + int filterD, int filterH, int filterW, + int strideD, int strideH, int strideW, + int paddingD, int paddingH, int paddingW, + real* data_col) {} + +inline void hl_matrix_col2Vol(real* data, + int channels, int depth, int height, int width, + int filterD, int filterH, int filterW, + int strideD, int strideH, int strideW, + int paddingD, int paddingH, int paddingW, + real* data_Im, + real alpha, real beta) {} + #endif // HL_MATRIX_STUB_H_ diff --git a/paddle/cuda/src/hl_cuda_matrix.cu b/paddle/cuda/src/hl_cuda_matrix.cu index 39272456c394a..f626c07a0c39a 100644 --- a/paddle/cuda/src/hl_cuda_matrix.cu +++ b/paddle/cuda/src/hl_cuda_matrix.cu @@ -592,3 +592,138 @@ void hl_matrix_rotate( mat, matRot, dimM, dimN, clockWise); CHECK_SYNC("hl_matrix_rotate failed"); } + + +__global__ void keMatrixVol2Col( + int num_kernels, real*dataSrc, real* dataDst, + int depth, int height, int width, + int filterD, int filterH, int filterW, + int strideD, int strideH, int strideW, + int paddingD, int paddingH, int paddingW, + int depth_col, int height_col, int width_col){ + + for (int index = blockIdx.x * blockDim.x + threadIdx.x; + index < num_kernels; + index += blockDim.x * gridDim.x){ + + int w_out = index % width_col; + int h_out = (index / width_col ) % height_col; + int d_out = (index / width_col / height_col) % depth_col; + int channel_in = index / width_col / height_col / depth_col; + int channel_out = channel_in * filterD * filterH * filterW; + int w_in = w_out * strideW - paddingW; + int h_in = h_out * strideH - paddingH; + int d_in = d_out * strideD - paddingD; + + dataDst += ((channel_out * depth_col + d_out) * height_col + h_out) * width_col + w_out; + dataSrc += ((channel_in * depth + d_in) * height + h_in) * width + w_in; + for (int k = 0; k < filterD; ++k) { + for (int i = 0; i < filterH; ++i) { + for (int j = 0; j < filterW; ++j) { + int d = d_in + k; + int h = h_in + i; + int w = w_in + j; + *dataDst = (d >= 0 && d < depth && h >= 0 && h < height && w >= 0 && w < width ) ? + dataSrc[(k * height + i) * width + j] : 0; + dataDst += depth_col * height_col * width_col; + } + } + } + } +} + +void hl_matrix_vol2Col(real* dataSrc, + int channels, int depth, int height, int width, + int filterD, int filterH, int filterW, + int strideD, int strideH, int strideW, + int paddingD, int paddingH, int paddingW, real* dataDst){ + + int depth_col = (depth + 2 * paddingD - filterD) / strideD + 1; + int height_col = (height + 2 * paddingH - filterH) / strideH + 1; + int width_col = (width + 2 * paddingW - filterW) / strideW + 1; + int num_kernels = channels * depth_col * height_col * width_col; + + const int threads = 512; + const int blocks = DIVUP(num_kernels, threads); + + keMatrixVol2Col<<< blocks, threads >>>( + num_kernels, dataSrc, dataDst, + depth, height, width, + filterD, filterH, filterW, + strideD, strideH, strideW, + paddingD, paddingH, paddingW, + depth_col, height_col, width_col); + CHECK_SYNC("hl_matrix_vol2Col failed"); +} + +__global__ void keMatrixCol2Vol( + int num_kernels, real*dataDst, real* dataSrc, + int depth, int height, int width, + int filterD, int filterH, int filterW, + int strideD, int strideH, int strideW, + int paddingD, int paddingH, int paddingW, + int depth_col, int height_col, int width_col, + real alpha, real beta){ + + for (int index = blockIdx.x * blockDim.x + threadIdx.x; + index < num_kernels; + index += blockDim.x * gridDim.x) { + + real val = 0; + int w = index % width + paddingW; + int h = (index / width) % height + paddingH; + int d = (index / width / height) % depth + paddingD; + int c = index / (width * height * depth); + // compute the start and end of the output + int w_col_start = (w < filterW) ? 0 : (w - filterW) / strideW + 1; + int w_col_end = min(w / strideW + 1, width_col); + int h_col_start = (h < filterH) ? 0 : (h - filterH) / strideH + 1; + int h_col_end = min(h / strideH + 1, height_col); + int d_col_start = (d < filterD) ? 0 : (d - filterD) / strideD + 1; + int d_col_end = min(d / strideD + 1, depth_col); + + int offset = (c * filterD * filterW * filterH + \ + d * filterW * filterH + h * filterW + w) * depth_col * height_col * width_col; + + int coeff_d_col = (1 - strideD * filterW * filterH * depth_col) * height_col * width_col; + int coeff_h_col = (1 - strideH * filterW * depth_col * height_col) * width_col; + int coeff_w_col = (1 - strideW * depth_col * height_col * width_col); + + for (int d_col = d_col_start; d_col < d_col_end; ++d_col) { + for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { + for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { + val += dataSrc[offset + d_col * coeff_d_col + h_col * coeff_h_col + w_col * coeff_w_col]; + } + } + } + dataDst[index] = val; + } +} + +void hl_matrix_col2Vol(real* dataDst, + int channels, int depth, int height, int width, + int filterD, int filterH, int filterW, + int strideD, int strideH, int strideW, + int paddingD, int paddingH, int paddingW, + real* dataSrc, + real alpha, real beta){ + + int depth_col = (depth + 2 * paddingD - filterD) / strideD + 1; + int height_col = (height + 2 * paddingH - filterH) / strideH + 1; + int width_col = (width + 2 * paddingW - filterW) / strideW + 1; + int num_kernels = channels * depth * height * width; + + const int threads = 512; + const int blocks = DIVUP(num_kernels, threads); + + keMatrixCol2Vol<<< blocks, threads >>>( + num_kernels, dataDst, dataSrc, + depth, height, width, + filterD, filterH, filterW, + strideD, strideH, strideW, + paddingD, paddingH, paddingW, + depth_col, height_col, width_col, + alpha, beta); + + CHECK_SYNC("hl_matrix_col2Vol failed"); +} From 8cc0eb9c5d564b71452e65d1bac3f9f19f5bf89e Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Sun, 13 Aug 2017 09:38:02 +0800 Subject: [PATCH 02/20] Modify ConvConfig, Add depth dimension --- proto/ModelConfig.proto | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto index 4f3d5bf3f6cb9..043ae502b0221 100644 --- a/proto/ModelConfig.proto +++ b/proto/ModelConfig.proto @@ -82,6 +82,12 @@ message ConvConfig { // if not set, use img_size optional uint32 img_size_y = 14; + + optional uint32 filter_size_z = 15 [ default = 1 ]; + optional uint32 padding_z = 16 [ default = 1 ]; + optional uint32 stride_z = 17 [ default = 1 ]; + optional uint32 output_z = 18 [ default = 1 ]; + optional uint32 img_size_z = 19 [ default = 1 ]; } message PoolConfig { @@ -631,4 +637,4 @@ message ModelConfig { // For External Machine, defining how to split a neural network // into multiple parts. optional ExternalConfig external_config = 9; -}; +}; \ No newline at end of file From 5d7f6dde52af781e15953c041374b5671bdf918d Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Sun, 13 Aug 2017 09:42:48 +0800 Subject: [PATCH 03/20] Add depth dimension information to ConvBaseLayer --- paddle/gserver/layers/ConvBaseLayer.cpp | 17 +++++++++++++---- paddle/gserver/layers/ConvBaseLayer.h | 8 ++++++++ 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/paddle/gserver/layers/ConvBaseLayer.cpp b/paddle/gserver/layers/ConvBaseLayer.cpp index e161d89c38a29..e437b0b86e0de 100644 --- a/paddle/gserver/layers/ConvBaseLayer.cpp +++ b/paddle/gserver/layers/ConvBaseLayer.cpp @@ -21,9 +21,11 @@ bool ConvBaseLayer::init(const LayerMap& layerMap, const ParameterMap& parameterMap) { /* Initialize the basic parent class */ Layer::init(layerMap, parameterMap); - isDeconv_ = (config_.type() == "exconv" || config_.type() == "cudnn_conv") - ? false - : true; + isDeconv_ = (config_.type() == "exconv" || + config_.type() == "cudnn_conv" || + config_.type() == "conv3d" || + config_.type() == "deconv3d" ) + ? false : true; /* Initialize the convolutional layer parameter */ numFilters_ = config_.num_filters(); @@ -36,7 +38,6 @@ bool ConvBaseLayer::init(const LayerMap& layerMap, paddingY_.push_back(conf.padding_y()); strideY_.push_back(conf.stride_y()); filterSizeY_.push_back(conf.filter_size_y()); - filterPixels_.push_back(filterSize_.back() * filterSizeY_.back()); channels_.push_back(conf.channels()); imgSizeH_.push_back(conf.has_img_size_y() ? conf.img_size_y() : conf.img_size()); @@ -45,6 +46,14 @@ bool ConvBaseLayer::init(const LayerMap& layerMap, filterChannels_.push_back(conf.filter_channels()); outputH_.push_back(conf.has_output_y() ? conf.output_y() : conf.output_x()); outputW_.push_back(conf.output_x()); + + paddingZ_.push_back(conf.padding_z()); + strideZ_.push_back(conf.stride_z()); + filterSizeZ_.push_back(conf.filter_size_z()); + imgSizeD_.push_back(conf.img_size_z()); + outputD_.push_back(conf.output_z()); + filterPixels_.push_back( + filterSize_.back() * filterSizeY_.back() * filterSizeZ_.back()); } CHECK(inputLayers_.size() == parameters_.size()); diff --git a/paddle/gserver/layers/ConvBaseLayer.h b/paddle/gserver/layers/ConvBaseLayer.h index e9d15d94f806a..8d1fd989e8383 100644 --- a/paddle/gserver/layers/ConvBaseLayer.h +++ b/paddle/gserver/layers/ConvBaseLayer.h @@ -23,6 +23,7 @@ namespace paddle { * with learned filters and (optionally) adds biases. */ + class ConvBaseLayer : public Layer { protected: typedef std::vector IntV; @@ -58,6 +59,13 @@ class ConvBaseLayer : public Layer { IntV outputH_; /// The spatial dimensions of output feature map width. IntV outputW_; + + IntV outputD_; + IntV imgSizeD_; + IntV filterSizeZ_; + IntV strideZ_; + IntV paddingZ_; + /// Group size, refer to grouped convolution in /// Alex Krizhevsky's paper: when group=2, the first half of the /// filters are only connected to the first half of the input channels, From 11975b4f9185907b5f2518722e5311d744361887 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Sun, 13 Aug 2017 09:47:37 +0800 Subject: [PATCH 04/20] Add Conv3DLayer --- paddle/gserver/layers/Conv3DLayer.cpp | 225 ++++++++++++++++++++++++++ paddle/gserver/layers/Conv3DLayer.h | 57 +++++++ 2 files changed, 282 insertions(+) create mode 100644 paddle/gserver/layers/Conv3DLayer.cpp create mode 100644 paddle/gserver/layers/Conv3DLayer.h diff --git a/paddle/gserver/layers/Conv3DLayer.cpp b/paddle/gserver/layers/Conv3DLayer.cpp new file mode 100644 index 0000000000000..0fa9c5f9f56df --- /dev/null +++ b/paddle/gserver/layers/Conv3DLayer.cpp @@ -0,0 +1,225 @@ +/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/utils/Logging.h" +#include "paddle/utils/Stat.h" +#include "Conv3DLayer.h" + +namespace paddle { + +REGISTER_LAYER(conv3d, Conv3DLayer); + +bool Conv3DLayer::init(const LayerMap &layerMap, + const ParameterMap ¶meterMap) { + if (!ConvBaseLayer::init(layerMap, parameterMap)) + return false; + int index = 0; + for (auto &inputConfig : config_.inputs()) { + const ConvConfig &conf = inputConfig.conv_conf(); + M_.push_back(numFilters_ / conf.groups()); + K_.push_back( + conf.filter_channels() * conf.filter_size_z() * \ + conf.filter_size_y() * conf.filter_size()); + weights_[index]->getW()->reshape( + weights_[index]->getW()->getWidth(), + weights_[index]->getW()->getHeight()); + weights_[index]->getWGrad()->reshape( + weights_[index]->getWGrad()->getWidth(), + weights_[index]->getWGrad()->getHeight()); + ++index; + } + biases_->getWGrad()->reshape( + biases_->getWGrad()->width_, biases_->getWGrad()->height_); + biases_->getW()->reshape( + biases_->getW()->width_, biases_->getW()->height_); + CHECK(inputLayers_.size() == parameters_.size()); + return true; +} + + +size_t Conv3DLayer::getSize() { + CHECK_NE(inputLayers_.size(), 0UL); + // imgSizeH_.clear(); + // imgSizeW_.clear(); + // imgSizeD_.clear(); + outputH_.clear(); + outputW_.clear(); + outputD_.clear(); + N_.clear(); + size_t layerSize = 0; + for (size_t i = 0; i < inputLayers_.size(); ++i) { + // imgSizeH_.push_back(inputLayers_[i]->getOutput().getFrameHeight()); + // imgSizeW_.push_back(inputLayers_[i]->getOutput().getFrameWidth()); + // imgSizeD_.push_back(inputLayers_[i]->getOutput().getFrameDepth()); + outputW_.push_back(outputSize( + imgSizeW_[i], filterSize_[i], + padding_[i], stride_[i], true)); + outputH_.push_back(outputSize( + imgSizeH_[i], filterSizeY_[i], + paddingY_[i], strideY_[i], true)); + outputD_.push_back(outputSize( + imgSizeD_[i], filterSizeZ_[i], + paddingZ_[i], strideZ_[i], true)); + + N_.push_back(outputD_[i] * outputH_[i] * outputW_[i]); + CHECK(layerSize == 0 || N_[i] * size_t(numFilters_) == layerSize); + layerSize += N_[i] * numFilters_; + } + getOutput().setFrameHeight(outputH_[0]); + getOutput().setFrameWidth(outputW_[0]); + getOutput().setFrameDepth(outputD_[0]); + return layerSize; +} + +void Conv3DLayer::forward(PassType passType) { + Layer::forward(passType); + + int batchSize = inputLayers_[0]->getOutputValue()->getHeight(); + int outWidth = getSize(); + resetOutput(batchSize, outWidth); + const MatrixPtr outMat = getOutputValue(); + + for (size_t i = 0; i != inputLayers_.size(); ++i) { + REGISTER_TIMER_INFO("FwdConv3D", getName().c_str()); + const MatrixPtr& inMat = getInputValue(i); + int width = inMat->getWidth(); + int M = M_[i]; + int N = N_[i]; + int K = K_[i]; + Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_); + MatrixPtr wMat = weights_[i]->getW(); + for (int n = 0; n < batchSize; ++n) { + colBuf_->vol2Col(inMat->getData() + n * width, channels_[i], + imgSizeD_[i], imgSizeH_[i], imgSizeW_[i], + filterSizeZ_[i], filterSizeY_[i], filterSize_[i], + strideZ_[i], strideY_[i], stride_[i], + paddingZ_[i], paddingY_[i], padding_[i]); + + real *outData = outMat->getData() + n * outWidth; + MatrixPtr outMatSub = + Matrix::create(outData, groups_[i] * M, N, false, useGpu_); + for (int g = 0; g < groups_[i]; g++) { + MatrixPtr wMatSub = wMat->subMatrix(g * M, M); + MatrixPtr in = colBuf_->subMatrix(g * K, K); + MatrixPtr out = outMatSub->subMatrix(g * M, M); + out->mul(*wMatSub, *in, 1.0, 0.0); + } + } + } + if (nullptr != this->biasParameter_) { + REGISTER_TIMER_INFO("FwBiasTimer", getName().c_str()); + this->addBias(); + } + forwardActivation(); +} + +void Conv3DLayer::backward(const UpdateCallback &callback) { + backwardActivation(); + + if (biases_ && biases_->getWGrad()) { + bpropBiases(); + biases_->getParameterPtr()->incUpdate(callback); + } + + for (size_t i = 0; i != inputLayers_.size(); ++i) { + REGISTER_TIMER_INFO("BwdConv3D", getName().c_str()); + if (weights_[i]->getWGrad()) { + bpropWeights(i); + } + if (this->needGradient_) { + bpropData(i); + } + REGISTER_TIMER_INFO("WeightUpdate", getName().c_str()); + weights_[i]->getParameterPtr()->incUpdate(callback); + } +} + +void Conv3DLayer::bpropWeights(int i) { + int M = M_[i]; + int N = N_[i]; + int K = K_[i]; + const MatrixPtr& inMat = getInputValue(i); + int width = inMat->getWidth(); + Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_); + MatrixPtr wGradMat = weights_[i]->getWGrad(); + real* outGradData = getOutputGrad()->getData(); + int batchSize = inputLayers_[0]->getOutputValue()->getHeight(); + + for (int n = 0; n < batchSize; ++n) { + colBuf_->vol2Col(inMat->getData() + n * width, channels_[i], + imgSizeD_[i], imgSizeH_[i], imgSizeW_[i], + filterSizeZ_[i], filterSizeY_[i], filterSize_[i], + strideZ_[i], strideY_[i], stride_[i], + paddingZ_[i], paddingY_[i], padding_[i]); + outGradData += n * getOutputGrad()->getWidth(); + MatrixPtr outGradSub = + Matrix::create(outGradData, groups_[i] * M, N, false, useGpu_); + for (int g = 0; g < groups_[i]; ++g) { + MatrixPtr inMatSub = colBuf_->subMatrix(g * K, K); + MatrixPtr outG = outGradSub->subMatrix(g * M, M); + MatrixPtr wGradSub = wGradMat->subMatrix(g * M, M); + wGradSub->mul(*outG, *(inMatSub->getTranspose()), 1.0, 1.0); + } + } +} + +void Conv3DLayer::bpropData(int i) { + int M = M_[i]; + int N = N_[i]; + int K = K_[i]; + Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_); + MatrixPtr wMat = weights_[i]->getW(); + real* outGradData = getOutputGrad()->getData(); + real* preGradData = getInputGrad(i)->getData(); + int batchSize = inputLayers_[0]->getOutputValue()->getHeight(); + for (int n = 0; n < batchSize; ++n) { + outGradData += n * getOutputGrad()->getWidth(); + preGradData += n * getInputGrad(i)->getWidth(); + MatrixPtr outGradSub = + Matrix::create(outGradData, M * groups_[i], N, false, useGpu_); + for (int g = 0; g < groups_[i]; ++g) { + MatrixPtr wMatSub = wMat->subMatrix(g * M, M); + MatrixPtr outG = outGradSub->subMatrix(g * M, M); + MatrixPtr inGradMatSub = colBuf_->subMatrix(g * K, K); + inGradMatSub->mul(*(wMatSub->getTranspose()), *outG, 1.0, 0.0); + } + colBuf_->col2Vol(preGradData, channels_[i], + imgSizeD_[i], imgSizeH_[i], imgSizeW_[i], + filterSizeZ_[i], filterSizeY_[i], filterSize_[i], + strideZ_[i], strideY_[i], stride_[i], + paddingZ_[i], paddingY_[i], padding_[i], + 1.0, 1.0); + } +} + +void Conv3DLayer::bpropBiases() { + MatrixPtr outGradMat = getOutputGrad(); + if (this->sharedBiases_) { + biases_->getWGrad()->collectSharedBias(*outGradMat, 1.0f); + } else { + biases_->getWGrad()->collectBias(*outGradMat, 1.0f); + } +} + +void Conv3DLayer::addBias() { + MatrixPtr outMat = getOutputValue(); + + if (this->sharedBiases_) { + outMat->addSharedBias(*(biases_->getW()), 1.0f); + } else { + outMat->addBias(*(biases_->getW()), 1.0f); + } +} + +} // namespace paddle diff --git a/paddle/gserver/layers/Conv3DLayer.h b/paddle/gserver/layers/Conv3DLayer.h new file mode 100644 index 0000000000000..703671e5d0dcb --- /dev/null +++ b/paddle/gserver/layers/Conv3DLayer.h @@ -0,0 +1,57 @@ +/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + + +#pragma once + +#include "ConvBaseLayer.h" +#include "paddle/math/Matrix.h" +#include "paddle/math/MathUtils.h" +#include + +namespace paddle { + +/** + * @brief A subclass of convolution layer. + * This layer expands input and use matrix multiplication to + * calculate convolution operation. + */ +class Conv3DLayer : public ConvBaseLayer { +public: + explicit Conv3DLayer(const LayerConfig& config) : ConvBaseLayer(config) {} + + ~Conv3DLayer() {} + + bool init(const LayerMap &layerMap, const ParameterMap ¶meterMap); + + size_t getSize(); + + void forward(PassType passType); + void addBias(); + + void backward(const UpdateCallback& callback); + + void bpropBiases(); + void bpropData(int i); + void bpropWeights(int i); + +protected: + // Figure out the dimensions for individual gemms. + IntV M_; /// numFilters_ / filter_group_; + IntV N_; /// channels_ * filterSizeZ_ * filterSize_ * filterSizeY_ + IntV K_; /// outputD_ * outputH_ * outputW_ + MatrixPtr colBuf_; +}; + +} // namespace paddle From 23cf0c61e066f54b360efc4e17576a056868b050 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Sun, 13 Aug 2017 09:48:59 +0800 Subject: [PATCH 05/20] Add DeConv3DLayer --- paddle/gserver/layers/DeConv3DLayer.cpp | 211 ++++++++++++++++++++++++ paddle/gserver/layers/DeConv3DLayer.h | 58 +++++++ 2 files changed, 269 insertions(+) create mode 100644 paddle/gserver/layers/DeConv3DLayer.cpp create mode 100644 paddle/gserver/layers/DeConv3DLayer.h diff --git a/paddle/gserver/layers/DeConv3DLayer.cpp b/paddle/gserver/layers/DeConv3DLayer.cpp new file mode 100644 index 0000000000000..8de40b681d72c --- /dev/null +++ b/paddle/gserver/layers/DeConv3DLayer.cpp @@ -0,0 +1,211 @@ +/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/utils/Logging.h" +#include "paddle/utils/Stat.h" +#include "DeConv3DLayer.h" + +namespace paddle { + +REGISTER_LAYER(deconv3d, DeConv3DLayer); + +#define DECONV_OUTPUT_SIZE(IN_SIZE, STRID, PAD, KSIZE) \ + (((IN_SIZE) - 1) * (STRID) - 2 * (PAD) + (KSIZE)) + +bool DeConv3DLayer::init(const LayerMap &layerMap, + const ParameterMap ¶meterMap) { + if (!ConvBaseLayer::init(layerMap, parameterMap)) return false; + // for Deconv, the dimension of Kernel is + // channel * output * depth * height * weigth + // Matrix storage format: (output * depth * height * weigth) x channel + for (int index = 0; index < config_.inputs().size(); ++index) { + M_.push_back(filterChannels_[index]); + K_.push_back( + filterPixels_[index] * (numFilters_/groups_[index])); + weights_[index]->getW()->reshape( + filterPixels_[index] * numFilters_, + filterChannels_[index]); + weights_[index]->getWGrad()->reshape( + filterPixels_[index] * numFilters_, + filterChannels_[index]); + } + biases_->getWGrad()->reshape( + biases_->getWGrad()->width_, biases_->getWGrad()->height_); + biases_->getW()->reshape( + biases_->getW()->width_, biases_->getW()->height_); + CHECK(inputLayers_.size() == parameters_.size()); + return true; +} + + +size_t DeConv3DLayer::getSize() { + CHECK_NE(inputLayers_.size(), 0UL); + // imgSizeH_.clear(); + // imgSizeW_.clear(); + // imgSizeD_.clear(); + outputH_.clear(); + outputW_.clear(); + outputD_.clear(); + N_.clear(); + No_.clear(); + size_t layerSize = 0; + for (size_t i = 0; i < inputLayers_.size(); ++i) { + // imgSizeH_.push_back(inputLayers_[i]->getOutput().getFrameHeight()); + // imgSizeW_.push_back(inputLayers_[i]->getOutput().getFrameWidth()); + // imgSizeD_.push_back(inputLayers_[i]->getOutput().getFrameDepth()); + outputW_.push_back( + DECONV_OUTPUT_SIZE( + imgSizeW_[i], stride_[i], + padding_[i], filterSize_[i])); + outputH_.push_back( + DECONV_OUTPUT_SIZE( + imgSizeH_[i], strideY_[i], + paddingY_[i], filterSizeY_[i])); + outputD_.push_back( + DECONV_OUTPUT_SIZE( + imgSizeD_[i], strideZ_[i], + paddingZ_[i], filterSizeZ_[i])); + No_.push_back(outputD_[i] * outputH_[i] * outputW_[i]); + N_.push_back(imgSizeD_[i] * imgSizeH_[i] * imgSizeW_[i]); + CHECK(layerSize == 0 || N_[i] * size_t(numFilters_) == layerSize); + layerSize += No_[i] * numFilters_; + } + getOutput().setFrameHeight(outputH_[0]); + getOutput().setFrameWidth(outputW_[0]); + getOutput().setFrameDepth(outputD_[0]); + return layerSize; +} + +void DeConv3DLayer::forward(PassType passType) { + Layer::forward(passType); + int batchSize = inputLayers_[0]->getOutputValue()->getHeight(); + int outWidth = getSize(); + resetOutput(batchSize, outWidth); + const MatrixPtr outMat = getOutputValue(); + + for (size_t i = 0; i != inputLayers_.size(); ++i) { + REGISTER_TIMER_INFO("FwdDeConv3D", getName().c_str()); + const MatrixPtr& inMat = getInputValue(i); + int width = inMat->getWidth(); + int M = M_[i]; + int N = N_[i]; + int K = K_[i]; + MatrixPtr wMat = weights_[i]->getW(); + Matrix::resizeOrCreate(colBuf_, K * groups_[i] , N, false, useGpu_); + + for (int n = 0; n < batchSize; ++n) { + real *inData = inMat->getData() + n * width; + real *colBufData = colBuf_->getData(); + for (int g = 0; g < groups_[i]; g++) { + MatrixPtr wMatSub = wMat->subMatrix(g * K, K); + MatrixPtr inMatSub = + Matrix::create(inData, M, N, false, useGpu_); + MatrixPtr colBufDataSub = + Matrix::create(colBufData, K, N, false, useGpu_); + colBufDataSub->mul(*wMatSub, *inMatSub, 1.0, 0.0); + colBufData += K * N; + inData += M * N; + } + colBuf_->col2Vol(outMat->getData()+ n * outMat->getWidth(), + numFilters_, outputD_[i], outputH_[i], outputW_[i], + filterSizeZ_[i], filterSizeY_[i], filterSize_[i], + strideZ_[i], strideY_[i], stride_[i], + paddingZ_[i], paddingY_[i], padding_[i], 1.0, 1.0); + } + } + if (nullptr != this->biasParameter_) { + REGISTER_TIMER_INFO("FwBiasTimer", getName().c_str()); + this->addBias(); + } + forwardActivation(); +} + +void DeConv3DLayer::backward(const UpdateCallback &callback) { + backwardActivation(); + int batchSize = getOutputGrad()->getHeight(); + int outputWidth = getOutputGrad()->getWidth(); + if (biases_ && biases_->getWGrad()) { + bpropBiases(); + biases_->getParameterPtr()->incUpdate(callback); + } + for (size_t i =0; i < inputLayers_.size(); ++i) { + int M = M_[i]; + int N = N_[i]; + int K = K_[i]; + Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_); + const MatrixPtr& inMat = getInputValue(i); + for (int n = 0; n < batchSize; ++n) { + REGISTER_TIMER_INFO("BwdDeConv3D", getName().c_str()); + if (weights_[i]->getWGrad() || this->needGradient_) { + colBuf_->vol2Col(getOutputGrad()->getData() + n * outputWidth, + numFilters_, outputD_[i], outputH_[i], outputW_[i], + filterSizeZ_[i], filterSizeY_[i], filterSize_[i], + strideZ_[i], strideY_[i], stride_[i], + paddingZ_[i], paddingY_[i], padding_[i]); + } + if (weights_[i]->getWGrad()) { + real *inData = inMat->getData() + n * inMat->getWidth();; + real *wGradData = weights_[i]->getWGrad()->getData(); + for (int g = 0; g < groups_[i]; g++) { + MatrixPtr colBufDataSub = colBuf_->subMatrix(g * K, K); + MatrixPtr inMatSub = Matrix::create( + inData, M, N, false, useGpu_); + MatrixPtr wGradMatSub = Matrix::create( + wGradData, K, M, false, useGpu_); + wGradMatSub->mul(*colBufDataSub, + *(inMatSub->getTranspose()), 1.0, 1.0); + wGradData += K * M; + inData += M * N; + } + weights_[i]->getParameterPtr()->incUpdate(callback); + } + if (this->needGradient_) { + real* preGrad = getInputGrad(i)->getData(); + for (int g = 0; g < groups_[i]; ++g) { + MatrixPtr w = weights_[i]->getW()->subMatrix(g * K, K); + MatrixPtr outGradMat = colBuf_->subMatrix(g * K, K); + MatrixPtr inGradMatSub = Matrix::create( + preGrad, M, N, false, useGpu_); + inGradMatSub->mul(*(w->getTranspose()), *outGradMat, 1.0, 0.0); + preGrad += M * N; + } + } + REGISTER_TIMER_INFO("WeightUpdate", getName().c_str()); + } + } +} + +void DeConv3DLayer::bpropWeights(int i) { } +void DeConv3DLayer::bpropData(int i) { } + +void DeConv3DLayer::bpropBiases() { + MatrixPtr outGradMat = getOutputGrad(); + + if (this->sharedBiases_) { + biases_->getWGrad()->collectSharedBias(*outGradMat, 1.0f); + } else { + biases_->getWGrad()->collectBias(*outGradMat, 1.0f); + } +} + +void DeConv3DLayer::addBias() { + MatrixPtr outMat = getOutputValue(); + if (this->sharedBiases_) { + outMat->addSharedBias(*(biases_->getW()), 1.0f); + } else { + outMat->addBias(*(biases_->getW()), 1.0f); + } +} + +} // namespace paddle diff --git a/paddle/gserver/layers/DeConv3DLayer.h b/paddle/gserver/layers/DeConv3DLayer.h new file mode 100644 index 0000000000000..435807fe5defa --- /dev/null +++ b/paddle/gserver/layers/DeConv3DLayer.h @@ -0,0 +1,58 @@ +/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + + +#pragma once + +#include "ConvBaseLayer.h" +#include "paddle/math/Matrix.h" +#include "paddle/math/MathUtils.h" +#include + +namespace paddle { + +/** + * @brief A subclass of deconvolution3D layer. + * This layer expands input and use matrix multiplication to + * calculate deconvolution3D operation. + */ +class DeConv3DLayer : public ConvBaseLayer { +public: + explicit DeConv3DLayer(const LayerConfig& config) : ConvBaseLayer(config) {} + + ~DeConv3DLayer() {} + + bool init(const LayerMap &layerMap, const ParameterMap ¶meterMap); + + size_t getSize(); + + void forward(PassType passType); + void addBias(); + + void backward(const UpdateCallback& callback); + + void bpropBiases(); + void bpropData(int i); + void bpropWeights(int i); + +protected: + // Figure out the dimensions for individual gemms. + IntV M_; /// numFilters_ / filter_group_; + IntV N_; /// channels_ * filterSizeZ_ * filterSize_ * filterSizeY_ + IntV K_; /// outputD_ * outputH_ * outputW_ + IntV No_; + MatrixPtr colBuf_; +}; + +} // namespace paddle From 52ceeedba5ca1371302414a0ad11ff93d9ed7d9a Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Sun, 13 Aug 2017 09:51:39 +0800 Subject: [PATCH 06/20] Add col2vol and vol2col CPU funtion --- paddle/math/Matrix.cpp | 135 +++++++++++++++++++++++++++++++++++++++++ paddle/math/Matrix.h | 64 +++++++++++++++++++ 2 files changed, 199 insertions(+) diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index 27f7d95b752d4..66868e73b33a2 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -1389,6 +1389,52 @@ void GpuMatrix::multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label) { output_d, grad_d, mat_d, height_, width_); } +void GpuMatrix::vol2Col(real* data, + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW) { + hl_matrix_vol2Col(data, + channels, depth, height, width, + filterD, filterH, filterW, + strideD, strideH, strideW, + paddingD, paddingH, paddingW, getData()); +} + +void GpuMatrix::col2Vol(real* trg, + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW, + real alpha, + real beta) { + hl_matrix_col2Vol(trg, + channels, depth, height, width, + filterD, filterH, filterW, + strideD, strideH, strideW, + paddingD, paddingH, paddingW, + getData(), + alpha, beta); + } + /** * CpuMatrix */ @@ -3975,6 +4021,95 @@ void CpuMatrix::bilinearBackward(const Matrix& out, } } +void CpuMatrix::vol2Col(real* data, + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW) { + real* outData = getData(); + int outHeight = (height + 2 * paddingH - filterH) / strideH + 1; + int outWidth = (width + 2 * paddingW - filterW) / strideW + 1; + int outDepth = (depth + 2 * paddingD - filterD) / strideD + 1; + + int channelsCol = channels * filterD * filterH * filterW; + for (int c = 0; c < channelsCol; ++c) { + int wOffset = c % filterW; + int hOffset = (c / filterW) % filterH; + int dOffset = (c / filterW / filterH) % filterD; + int cIn = c / filterW / filterH / filterD; + for (int d = 0; d < outDepth; ++d) { + for (int h = 0; h < outHeight; ++h) { + for (int w = 0; w < outWidth; ++w) { + int dPad = d * strideD - paddingD + dOffset; + int hPad = h * strideH - paddingH + hOffset; + int wPad = w * strideW - paddingW + wOffset; + + if (hPad >= 0 && hPad < height && wPad >= 0 && wPad < width && + dPad >= 0 && dPad < depth) + outData[((c * outDepth + d) * outHeight + h) * outWidth + w] = + data[((cIn * depth + dPad) * height + hPad) * width + wPad]; + else + outData[((c * outDepth + d) * outHeight + h) * outWidth + w] = 0; + } + } + } + } +} + +void CpuMatrix::col2Vol(real* trg, + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW, + real alpha, + real beta) { + real* src = getData(); + int outDepth = (depth + 2 * paddingH - filterD) / strideD + 1; + int outHeight = (height + 2 * paddingH - filterH) / strideH + 1; + int outWidth = (width + 2 * paddingW - filterW) / strideW + 1; + int channelsCol = channels * filterD * filterH * filterW; + for (int c = 0; c < channelsCol; ++c) { + int wOffset = c % filterW; + int hOffset = (c / filterW) % filterH; + int dOffset = (c / filterW / filterH) % filterD; + int cIm = c / filterW / filterH / filterD; + for (int d = 0; d < outDepth; ++d) { + for (int h = 0; h < outHeight; ++h) { + for (int w = 0; w < outWidth; ++w) { + int dPad = d * strideD - paddingD + dOffset; + int hPad = h * strideH - paddingH + hOffset; + int wPad = w * strideW - paddingW + wOffset; + if (hPad >= 0 && hPad < height && wPad >= 0 && wPad < width && + dPad >= 0 && dPad < depth) + trg[((cIm * depth + dPad) * height + hPad) * width + wPad] = + alpha * + src[((c * outDepth + d) * outHeight + h) * outWidth + w] + + beta * + trg[((cIm * depth + dPad) * height + hPad) * width + wPad]; + } + } + } + } +} + //////////////////////////////////////////////////////////////// // functions executed via cpu // //////////////////////////////////////////////////////////////// diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h index bb802bbb2c752..4354996ce0b02 100644 --- a/paddle/math/Matrix.h +++ b/paddle/math/Matrix.h @@ -1039,6 +1039,42 @@ class Matrix : public BaseMatrix { LOG(FATAL) << "Not implemented"; } + virtual void vol2Col(real* data, + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW) { + LOG(FATAL) << "Not implemeted"; + } + + virtual void col2Vol(real* trg, + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW, + real alpha, + real beta) { + LOG(FATAL) << "Not implemeted"; + } + virtual void bilinearForward(const Matrix& in, const size_t inImgH, const size_t inImgW, @@ -1374,6 +1410,20 @@ class GpuMatrix : public Matrix { const real ratioH, const real ratioW); + void vol2Col(real* data, + int channels, + int depth, int height, int width, + int filterD, int filterH, int filterW, + int strideD, int strideH, int strideW, + int paddingD, int paddingH, int paddingW); + + void col2Vol(real* trg, + int channels, int depth, int height, int width, + int filterD, int filterH, int filterW, + int strideD, int strideH, int strideW, + int paddingD, int paddingH, int paddingW, + real alpha, real beta); + void multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label); void multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label); @@ -1715,6 +1765,20 @@ class CpuMatrix : public Matrix { const real ratioH, const real ratioW); + void vol2Col(real* data, + int channels, + int depth, int height, int width, + int filterD, int filterH, int filterW, + int strideD, int strideH, int strideW, + int paddingD, int paddingH, int paddingW); + + void col2Vol(real* trg, + int channels, int depth, int height, int width, + int filterD, int filterH, int filterW, + int strideD, int strideH, int strideW, + int paddingD, int paddingH, int paddingW, + real alpha, real beta); + template void operator=(const ExpressionType& expr) { TensorCpuApply(*this, expr); From 9b3d6acdbfc2fd6bc26185ddb9c38dfb90632324 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Sun, 13 Aug 2017 09:54:10 +0800 Subject: [PATCH 07/20] Add depth dimension information to Argument --- paddle/parameter/Argument.cpp | 2 ++ paddle/parameter/Argument.h | 8 +++++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/paddle/parameter/Argument.cpp b/paddle/parameter/Argument.cpp index 0547ac93cd183..77fd0c5890b45 100644 --- a/paddle/parameter/Argument.cpp +++ b/paddle/parameter/Argument.cpp @@ -186,6 +186,7 @@ void Argument::resizeAndCopyFrom(const Argument& src, resizeAndCopy(strs, src.strs, useGpu, stream); frameWidth = src.frameWidth; frameHeight = src.frameHeight; + frameDepth = src.frameDepth; } int32_t Argument::resizeAndCopyFrom(const Argument& src, @@ -206,6 +207,7 @@ int32_t Argument::resizeAndCopyFrom(const Argument& src, dataId = src.dataId; frameWidth = src.frameWidth; frameHeight = src.frameHeight; + frameDepth = src.frameDepth; if (!src.sequenceStartPositions) { // non-sequence input, copy samples directly diff --git a/paddle/parameter/Argument.h b/paddle/parameter/Argument.h index d8d7a4398f99a..ba3ad2fd4d992 100644 --- a/paddle/parameter/Argument.h +++ b/paddle/parameter/Argument.h @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -35,6 +32,7 @@ struct Argument { strs(nullptr), frameHeight(0), frameWidth(0), + frameDepth(0), sequenceStartPositions(nullptr), subSequenceStartPositions(nullptr), cpuSequenceDims(nullptr), @@ -64,6 +62,7 @@ struct Argument { allCount = argument.allCount; frameHeight = argument.frameHeight; frameWidth = argument.frameWidth; + frameDepth = argument.frameDepth; dataId = argument.dataId; } @@ -76,6 +75,7 @@ struct Argument { // A dataBatch includes batchSize frames, one frame maybe not only vector size_t frameHeight; size_t frameWidth; + size_t frameDepth; // If NULL, each position is treated independently. // Otherwise, its size should be #NumberOfSequences + 1. @@ -136,8 +136,10 @@ struct Argument { } size_t getFrameHeight() const { return frameHeight; } size_t getFrameWidth() const { return frameWidth; } + size_t getFrameDepth() const { return frameDepth; } void setFrameHeight(size_t h) { frameHeight = h; } void setFrameWidth(size_t w) { frameWidth = w; } + void setFrameDepth(size_t d) { frameDepth = d; } int64_t getNumSequences() const { return sequenceStartPositions ? sequenceStartPositions->getSize() - 1 From 424b325d084ef0fd5aa61996f35ef88126c48306 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Fri, 18 Aug 2017 14:10:27 +0800 Subject: [PATCH 08/20] add unit test DeConv3D, Conv3D, col2vol, vol2col --- paddle/gserver/tests/test_LayerGrad.cpp | 152 +++++++++++++++++++++++ paddle/math/tests/test_matrixCompare.cpp | 116 +++++++++++++++++ 2 files changed, 268 insertions(+) diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index 0f312b6ca50bc..1e80e2c0ee0f8 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -2007,6 +2007,158 @@ TEST(Layer, RowL2NormLayer) { } } +void test3DConvLayer(const string& type, bool trans, bool useGpu) { + // filter size + const int NUM_FILTERS = 6; + // const int CHANNELS = 3; + const int FILTER_SIZE = 3; + const int FILTER_SIZE_Y = 3; + const int FILTER_SIZE_Z = 3; + + // input image + const int CHANNELS = 3; + const int IMAGE_SIZE = 9; + const int IMAGE_SIZE_Y = 9; + const int IMAGE_SIZE_Z = 9; // 2, 3, 5, 5, 5 + + TestConfig config; + config.biasSize = NUM_FILTERS; + config.layerConfig.set_type(type); + config.layerConfig.set_num_filters(NUM_FILTERS); + config.layerConfig.set_partial_sum(1); + config.layerConfig.set_shared_biases(true); + + // Setting up conv3D-trans layer + LayerInputConfig* input = config.layerConfig.add_inputs(); + ConvConfig* conv = input->mutable_conv_conf(); + + conv->set_channels(CHANNELS); + conv->set_filter_size(FILTER_SIZE); + conv->set_filter_size_y(FILTER_SIZE_Y); + conv->set_filter_size_z(FILTER_SIZE_Z); + conv->set_padding(0); + conv->set_padding_y(0); + conv->set_padding_z(0); + conv->set_stride(2); + conv->set_stride_y(2); + conv->set_stride_z(2); + conv->set_img_size(IMAGE_SIZE); + conv->set_img_size_y(IMAGE_SIZE_Y); + conv->set_img_size_z(IMAGE_SIZE_Z); + conv->set_output_x(outputSize(conv->img_size(), + conv->filter_size(), + conv->padding(), + conv->stride(), + /* caffeMode */ true)); + conv->set_output_y(outputSize(conv->img_size_y(), + conv->filter_size_y(), + conv->padding_y(), + conv->stride_y(), + /* caffeMode */ true)); + conv->set_output_z(outputSize(conv->img_size_z(), + conv->filter_size_z(), + conv->padding_z(), + conv->stride_z(), + /* caffeMode */ true)); + + config.layerConfig.set_size(conv->output_x() * conv->output_y() * + conv->output_z() * NUM_FILTERS); + conv->set_groups(1); + conv->set_filter_channels(conv->channels() / conv->groups()); + config.inputDefs.push_back( + {INPUT_DATA, + "layer_0", + CHANNELS * IMAGE_SIZE * IMAGE_SIZE_Y * IMAGE_SIZE_Z, + conv->filter_channels() * FILTER_SIZE * FILTER_SIZE_Y * FILTER_SIZE_Z * + NUM_FILTERS}); + + testLayerGrad(config, "conv3D", 10, trans, useGpu); + // Use small batch_size and useWeight=true to test biasGrad + testLayerGrad(config, "conv3D", 2, trans, useGpu, true, 0.02); +} + +TEST(Layer, test3DConvLayer) { + test3DConvLayer("conv3d", /* trans= */ false, /* useGpu= */ false); +#ifndef PADDLE_ONLY_CPU + test3DConvLayer("conv3d", /* trans= */ false, /* useGpu= */ true); +#endif +} + +int deConvOutputSize(int inSize, int kSize, int pad, int stride) { + return (inSize - 1) * stride - 2 * pad + kSize; +} + +void test3DDeConvLayer(const string& type, bool trans, bool useGpu) { + // filter size + const int NUM_FILTERS = 6; + // const int CHANNELS = 3; + const int FILTER_SIZE = 3; + const int FILTER_SIZE_Y = 3; + const int FILTER_SIZE_Z = 3; + + // input image + const int CHANNELS = 3; + const int IMAGE_SIZE = 4; + const int IMAGE_SIZE_Y = 6; + const int IMAGE_SIZE_Z = 6; + + // Setting up conv-trans layer + TestConfig config; + config.biasSize = NUM_FILTERS; + config.layerConfig.set_type("deconv3d"); + config.layerConfig.set_num_filters(NUM_FILTERS); + config.layerConfig.set_partial_sum(1); + config.layerConfig.set_shared_biases(true); + + LayerInputConfig* input = config.layerConfig.add_inputs(); + ConvConfig* conv = input->mutable_conv_conf(); + + conv->set_channels(CHANNELS); + conv->set_filter_size(FILTER_SIZE); + conv->set_filter_size_y(FILTER_SIZE_Y); + conv->set_filter_size_z(FILTER_SIZE_Z); + conv->set_padding(0); + conv->set_padding_y(0); + conv->set_padding_z(0); + conv->set_stride(2); + conv->set_stride_y(2); + conv->set_stride_z(2); + conv->set_img_size(IMAGE_SIZE); + conv->set_img_size_y(IMAGE_SIZE_Y); + conv->set_img_size_z(IMAGE_SIZE_Z); + conv->set_output_x(deConvOutputSize( + conv->img_size(), conv->filter_size(), conv->padding(), conv->stride())); + conv->set_output_y(deConvOutputSize(conv->img_size_y(), + conv->filter_size_y(), + conv->padding_y(), + conv->stride_y())); + conv->set_output_z(deConvOutputSize(conv->img_size_z(), + conv->filter_size_z(), + conv->padding_z(), + conv->stride_z())); + config.layerConfig.set_size(conv->output_x() * conv->output_y() * + conv->output_z() * NUM_FILTERS); + conv->set_groups(1); + conv->set_filter_channels(conv->channels() / conv->groups()); + config.inputDefs.push_back( + {INPUT_DATA, + "layer_0", + CHANNELS * IMAGE_SIZE * IMAGE_SIZE_Y * IMAGE_SIZE_Z, + conv->filter_channels() * FILTER_SIZE * FILTER_SIZE_Y * FILTER_SIZE_Z * + NUM_FILTERS}); + + testLayerGrad(config, "deconv3D", 10, trans, useGpu); + // Use small batch_size and useWeight=true to test biasGrad + testLayerGrad(config, "deconv3D", 2, trans, useGpu, true, 0.02); +} + +TEST(Layer, test3DDeConvLayer) { + test3DDeConvLayer("deconv3d", /* trans= */ false, /* useGpu= */ false); +#ifndef PADDLE_ONLY_CPU + test3DDeConvLayer("deconv3d", /* trans= */ false, /* useGpu= */ true); +#endif +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); initMain(argc, argv); diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index d77478f345df9..1d41ec087028a 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -1203,4 +1203,120 @@ TEST(Matrix, warpCTC) { } } +int outputSizeCol2Vol( + int imageSize, int filterSize, int padding, int stride, bool caffeMode) { + int outputSize; + if (!caffeMode) { + outputSize = + (imageSize - filterSize + 2 * padding + stride - 1) / stride + 1; + } else { + outputSize = (imageSize - filterSize + 2 * padding) / stride + 1; + } + CHECK_GE(outputSize, 1); + return outputSize; +} + +void testMatrixCol2Vol(int depth, int height, int width) { + int channel = 3; + int filterX = 3, filterY = 4, filterZ = 5; + int strideX = 2, strideY = 2, strideZ = 2; + int padX = 1, padY = 1, padZ = 1; + + MatrixPtr cpuImage = + std::make_shared(channel, depth * height * width); + MatrixPtr gpuImage = + std::make_shared(channel, depth * height * width); + cpuImage->randomizeUniform(); + gpuImage->copyFrom(*cpuImage); + + int outD = outputSizeCol2Vol(depth, filterZ, padZ, strideZ, true); + int outH = outputSizeCol2Vol(height, filterY, padZ, strideY, true); + int outW = outputSizeCol2Vol(width, filterX, padZ, strideX, true); + + int colBufHeight = channel * filterZ * filterY * filterX; + int colBufWidth = outD * outH * outW; + MatrixPtr cpuColBuf = std::make_shared(colBufHeight, colBufWidth); + MatrixPtr gpuColBuf = std::make_shared(colBufHeight, colBufWidth); + cpuColBuf->vol2Col(cpuImage->getData(), + channel, + depth, + height, + width, + filterZ, + filterY, + filterX, + strideZ, + strideY, + strideX, + padZ, + padY, + padX); + gpuColBuf->vol2Col(gpuImage->getData(), + channel, + depth, + height, + width, + filterZ, + filterY, + filterX, + strideZ, + strideY, + strideX, + padZ, + padY, + padX); + TensorCheckEqual(*cpuColBuf, *gpuColBuf); + + cpuColBuf->randomizeUniform(); + gpuColBuf->copyFrom(*cpuColBuf); + cpuColBuf->col2Vol(cpuImage->getData(), + channel, + depth, + height, + width, + filterZ, + filterY, + filterX, + strideZ, + strideY, + strideX, + padZ, + padY, + padX, + 1.0, + 1.0); + gpuColBuf->col2Vol(gpuImage->getData(), + channel, + depth, + height, + width, + filterZ, + filterY, + filterX, + strideZ, + strideY, + strideX, + padZ, + padY, + padX, + 1.0, + 1.0); + TensorCheckErr(*cpuImage, *gpuImage); +} + +TEST(Matrix, col2Vol) { + for (auto depth : {9, 16, 64, 128}) { + for (auto height : {9, 11, 73, 128, 256}) { + for (auto width : { + 9, 32, 100, 512, + }) { + VLOG(3) << "depth=" << depth << " height=" << height + << " width=" << width; + testMatrixCol2Vol(depth, height, width); + } + } + } +} +/////// + #endif From c792ef7d5ae470031bebcd990b79c0ce7f36f7bc Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Fri, 18 Aug 2017 14:12:01 +0800 Subject: [PATCH 09/20] fix DeConv3D, Conv3D --- paddle/gserver/layers/Conv3DLayer.cpp | 248 +++++++++++++----------- paddle/gserver/layers/DeConv3DLayer.cpp | 186 +++++++++--------- 2 files changed, 229 insertions(+), 205 deletions(-) diff --git a/paddle/gserver/layers/Conv3DLayer.cpp b/paddle/gserver/layers/Conv3DLayer.cpp index 0fa9c5f9f56df..5609a4cc73562 100644 --- a/paddle/gserver/layers/Conv3DLayer.cpp +++ b/paddle/gserver/layers/Conv3DLayer.cpp @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "Conv3DLayer.h" #include "paddle/utils/Logging.h" #include "paddle/utils/Stat.h" -#include "Conv3DLayer.h" namespace paddle { @@ -22,32 +22,30 @@ REGISTER_LAYER(conv3d, Conv3DLayer); bool Conv3DLayer::init(const LayerMap &layerMap, const ParameterMap ¶meterMap) { - if (!ConvBaseLayer::init(layerMap, parameterMap)) - return false; + if (!ConvBaseLayer::init(layerMap, parameterMap)) return false; int index = 0; for (auto &inputConfig : config_.inputs()) { - const ConvConfig &conf = inputConfig.conv_conf(); - M_.push_back(numFilters_ / conf.groups()); - K_.push_back( - conf.filter_channels() * conf.filter_size_z() * \ - conf.filter_size_y() * conf.filter_size()); - weights_[index]->getW()->reshape( - weights_[index]->getW()->getWidth(), - weights_[index]->getW()->getHeight()); + const ConvConfig &conf = inputConfig.conv_conf(); + M_.push_back(numFilters_ / conf.groups()); + K_.push_back(filterPixels_[index] * filterChannels_[index]); + if (nullptr != weights_[index]->getW()) + weights_[index]->getW()->reshape(weights_[index]->getW()->getWidth(), + weights_[index]->getW()->getHeight()); + if (nullptr != weights_[index]->getWGrad()) weights_[index]->getWGrad()->reshape( - weights_[index]->getWGrad()->getWidth(), - weights_[index]->getWGrad()->getHeight()); - ++index; + weights_[index]->getWGrad()->getWidth(), + weights_[index]->getWGrad()->getHeight()); + ++index; } - biases_->getWGrad()->reshape( - biases_->getWGrad()->width_, biases_->getWGrad()->height_); - biases_->getW()->reshape( - biases_->getW()->width_, biases_->getW()->height_); + if (nullptr != biases_->getWGrad()) + biases_->getWGrad()->reshape(biases_->getWGrad()->width_, + biases_->getWGrad()->height_); + if (nullptr != biases_->getW()) + biases_->getW()->reshape(biases_->getW()->width_, biases_->getW()->height_); CHECK(inputLayers_.size() == parameters_.size()); return true; } - size_t Conv3DLayer::getSize() { CHECK_NE(inputLayers_.size(), 0UL); // imgSizeH_.clear(); @@ -59,22 +57,19 @@ size_t Conv3DLayer::getSize() { N_.clear(); size_t layerSize = 0; for (size_t i = 0; i < inputLayers_.size(); ++i) { - // imgSizeH_.push_back(inputLayers_[i]->getOutput().getFrameHeight()); - // imgSizeW_.push_back(inputLayers_[i]->getOutput().getFrameWidth()); - // imgSizeD_.push_back(inputLayers_[i]->getOutput().getFrameDepth()); - outputW_.push_back(outputSize( - imgSizeW_[i], filterSize_[i], - padding_[i], stride_[i], true)); - outputH_.push_back(outputSize( - imgSizeH_[i], filterSizeY_[i], - paddingY_[i], strideY_[i], true)); - outputD_.push_back(outputSize( - imgSizeD_[i], filterSizeZ_[i], - paddingZ_[i], strideZ_[i], true)); - - N_.push_back(outputD_[i] * outputH_[i] * outputW_[i]); - CHECK(layerSize == 0 || N_[i] * size_t(numFilters_) == layerSize); - layerSize += N_[i] * numFilters_; + // imgSizeH_.push_back(inputLayers_[i]->getOutput().getFrameHeight()); + // imgSizeW_.push_back(inputLayers_[i]->getOutput().getFrameWidth()); + // imgSizeD_.push_back(inputLayers_[i]->getOutput().getFrameDepth()); + outputW_.push_back(outputSize( + imgSizeW_[i], filterSize_[i], padding_[i], stride_[i], true)); + outputH_.push_back(outputSize( + imgSizeH_[i], filterSizeY_[i], paddingY_[i], strideY_[i], true)); + outputD_.push_back(outputSize( + imgSizeD_[i], filterSizeZ_[i], paddingZ_[i], strideZ_[i], true)); + + N_.push_back(outputD_[i] * outputH_[i] * outputW_[i]); + CHECK(layerSize == 0 || N_[i] * size_t(numFilters_) == layerSize); + layerSize += N_[i] * numFilters_; } getOutput().setFrameHeight(outputH_[0]); getOutput().setFrameWidth(outputW_[0]); @@ -88,38 +83,46 @@ void Conv3DLayer::forward(PassType passType) { int batchSize = inputLayers_[0]->getOutputValue()->getHeight(); int outWidth = getSize(); resetOutput(batchSize, outWidth); - const MatrixPtr outMat = getOutputValue(); for (size_t i = 0; i != inputLayers_.size(); ++i) { - REGISTER_TIMER_INFO("FwdConv3D", getName().c_str()); - const MatrixPtr& inMat = getInputValue(i); - int width = inMat->getWidth(); - int M = M_[i]; - int N = N_[i]; - int K = K_[i]; - Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_); - MatrixPtr wMat = weights_[i]->getW(); - for (int n = 0; n < batchSize; ++n) { - colBuf_->vol2Col(inMat->getData() + n * width, channels_[i], - imgSizeD_[i], imgSizeH_[i], imgSizeW_[i], - filterSizeZ_[i], filterSizeY_[i], filterSize_[i], - strideZ_[i], strideY_[i], stride_[i], - paddingZ_[i], paddingY_[i], padding_[i]); - - real *outData = outMat->getData() + n * outWidth; - MatrixPtr outMatSub = - Matrix::create(outData, groups_[i] * M, N, false, useGpu_); - for (int g = 0; g < groups_[i]; g++) { - MatrixPtr wMatSub = wMat->subMatrix(g * M, M); - MatrixPtr in = colBuf_->subMatrix(g * K, K); - MatrixPtr out = outMatSub->subMatrix(g * M, M); - out->mul(*wMatSub, *in, 1.0, 0.0); - } + REGISTER_TIMER_INFO("FwdConv3D", getName().c_str()); + const MatrixPtr &inMat = getInputValue(i); + const MatrixPtr &outMat = getOutputValue(); + int M = M_[i]; + int N = N_[i]; + int K = K_[i]; + Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_); + MatrixPtr wMat = weights_[i]->getW(); + for (int n = 0; n < batchSize; ++n) { + colBuf_->vol2Col(inMat->getData() + n * inMat->getStride(), + channels_[i], + imgSizeD_[i], + imgSizeH_[i], + imgSizeW_[i], + filterSizeZ_[i], + filterSizeY_[i], + filterSize_[i], + strideZ_[i], + strideY_[i], + stride_[i], + paddingZ_[i], + paddingY_[i], + padding_[i]); + + real *outData = outMat->getData() + n * outMat->getStride(); + MatrixPtr outMatSub = + Matrix::create(outData, groups_[i] * M, N, false, useGpu_); + for (int g = 0; g < groups_[i]; g++) { + MatrixPtr wMatSub = wMat->subMatrix(g * M, M); + MatrixPtr in = colBuf_->subMatrix(g * K, K); + MatrixPtr out = outMatSub->subMatrix(g * M, M); + out->mul(*wMatSub, *in, 1.0, 1.0); } + } } if (nullptr != this->biasParameter_) { - REGISTER_TIMER_INFO("FwBiasTimer", getName().c_str()); - this->addBias(); + REGISTER_TIMER_INFO("FwBiasTimer", getName().c_str()); + this->addBias(); } forwardActivation(); } @@ -128,20 +131,20 @@ void Conv3DLayer::backward(const UpdateCallback &callback) { backwardActivation(); if (biases_ && biases_->getWGrad()) { - bpropBiases(); - biases_->getParameterPtr()->incUpdate(callback); + bpropBiases(); + biases_->getParameterPtr()->incUpdate(callback); } for (size_t i = 0; i != inputLayers_.size(); ++i) { - REGISTER_TIMER_INFO("BwdConv3D", getName().c_str()); - if (weights_[i]->getWGrad()) { - bpropWeights(i); - } - if (this->needGradient_) { - bpropData(i); - } - REGISTER_TIMER_INFO("WeightUpdate", getName().c_str()); - weights_[i]->getParameterPtr()->incUpdate(callback); + REGISTER_TIMER_INFO("BwdConv3D", getName().c_str()); + if (weights_[i]->getWGrad()) { + bpropWeights(i); + } + if (getInputGrad(i)) { + bpropData(i); + } + REGISTER_TIMER_INFO("WeightUpdate", getName().c_str()); + weights_[i]->getParameterPtr()->incUpdate(callback); } } @@ -149,28 +152,36 @@ void Conv3DLayer::bpropWeights(int i) { int M = M_[i]; int N = N_[i]; int K = K_[i]; - const MatrixPtr& inMat = getInputValue(i); - int width = inMat->getWidth(); + const MatrixPtr &inMat = getInputValue(i); Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_); MatrixPtr wGradMat = weights_[i]->getWGrad(); - real* outGradData = getOutputGrad()->getData(); int batchSize = inputLayers_[0]->getOutputValue()->getHeight(); - for (int n = 0; n < batchSize; ++n) { - colBuf_->vol2Col(inMat->getData() + n * width, channels_[i], - imgSizeD_[i], imgSizeH_[i], imgSizeW_[i], - filterSizeZ_[i], filterSizeY_[i], filterSize_[i], - strideZ_[i], strideY_[i], stride_[i], - paddingZ_[i], paddingY_[i], padding_[i]); - outGradData += n * getOutputGrad()->getWidth(); - MatrixPtr outGradSub = - Matrix::create(outGradData, groups_[i] * M, N, false, useGpu_); - for (int g = 0; g < groups_[i]; ++g) { - MatrixPtr inMatSub = colBuf_->subMatrix(g * K, K); - MatrixPtr outG = outGradSub->subMatrix(g * M, M); - MatrixPtr wGradSub = wGradMat->subMatrix(g * M, M); - wGradSub->mul(*outG, *(inMatSub->getTranspose()), 1.0, 1.0); - } + colBuf_->vol2Col(inMat->getData() + n * inMat->getStride(), + channels_[i], + imgSizeD_[i], + imgSizeH_[i], + imgSizeW_[i], + filterSizeZ_[i], + filterSizeY_[i], + filterSize_[i], + strideZ_[i], + strideY_[i], + stride_[i], + paddingZ_[i], + paddingY_[i], + padding_[i]); + + real *outGradData = + getOutputGrad()->getData() + n * getOutputGrad()->getStride(); + MatrixPtr outGradSub = + Matrix::create(outGradData, groups_[i] * M, N, false, useGpu_); + for (int g = 0; g < groups_[i]; ++g) { + MatrixPtr inMatSub = colBuf_->subMatrix(g * K, K); + MatrixPtr outG = outGradSub->subMatrix(g * M, M); + MatrixPtr wGradSub = wGradMat->subMatrix(g * M, M); + wGradSub->mul(*outG, *(inMatSub->getTranspose()), 1.0, 1.0); + } } } @@ -180,45 +191,54 @@ void Conv3DLayer::bpropData(int i) { int K = K_[i]; Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_); MatrixPtr wMat = weights_[i]->getW(); - real* outGradData = getOutputGrad()->getData(); - real* preGradData = getInputGrad(i)->getData(); int batchSize = inputLayers_[0]->getOutputValue()->getHeight(); for (int n = 0; n < batchSize; ++n) { - outGradData += n * getOutputGrad()->getWidth(); - preGradData += n * getInputGrad(i)->getWidth(); - MatrixPtr outGradSub = - Matrix::create(outGradData, M * groups_[i], N, false, useGpu_); - for (int g = 0; g < groups_[i]; ++g) { - MatrixPtr wMatSub = wMat->subMatrix(g * M, M); - MatrixPtr outG = outGradSub->subMatrix(g * M, M); - MatrixPtr inGradMatSub = colBuf_->subMatrix(g * K, K); - inGradMatSub->mul(*(wMatSub->getTranspose()), *outG, 1.0, 0.0); - } - colBuf_->col2Vol(preGradData, channels_[i], - imgSizeD_[i], imgSizeH_[i], imgSizeW_[i], - filterSizeZ_[i], filterSizeY_[i], filterSize_[i], - strideZ_[i], strideY_[i], stride_[i], - paddingZ_[i], paddingY_[i], padding_[i], - 1.0, 1.0); + real *outGradData = + getOutputGrad()->getData() + n * getOutputGrad()->getStride(); + real *preGradData = + getInputGrad(i)->getData() + n * getInputGrad(i)->getStride(); + MatrixPtr outGradSub = + Matrix::create(outGradData, M * groups_[i], N, false, useGpu_); + for (int g = 0; g < groups_[i]; ++g) { + MatrixPtr wMatSub = wMat->subMatrix(g * M, M); + MatrixPtr outG = outGradSub->subMatrix(g * M, M); + MatrixPtr inGradMatSub = colBuf_->subMatrix(g * K, K); + inGradMatSub->mul(*(wMatSub->getTranspose()), *outG, 1.0, 0.0); + } + colBuf_->col2Vol(preGradData, + channels_[i], + imgSizeD_[i], + imgSizeH_[i], + imgSizeW_[i], + filterSizeZ_[i], + filterSizeY_[i], + filterSize_[i], + strideZ_[i], + strideY_[i], + stride_[i], + paddingZ_[i], + paddingY_[i], + padding_[i], + 1.0, + 1.0); } } void Conv3DLayer::bpropBiases() { MatrixPtr outGradMat = getOutputGrad(); if (this->sharedBiases_) { - biases_->getWGrad()->collectSharedBias(*outGradMat, 1.0f); + biases_->getWGrad()->collectSharedBias(*outGradMat, 1.0f); } else { - biases_->getWGrad()->collectBias(*outGradMat, 1.0f); + biases_->getWGrad()->collectBias(*outGradMat, 1.0f); } } void Conv3DLayer::addBias() { MatrixPtr outMat = getOutputValue(); - if (this->sharedBiases_) { - outMat->addSharedBias(*(biases_->getW()), 1.0f); + outMat->addSharedBias(*(biases_->getW()), 1.0f); } else { - outMat->addBias(*(biases_->getW()), 1.0f); + outMat->addBias(*(biases_->getW()), 1.0f); } } diff --git a/paddle/gserver/layers/DeConv3DLayer.cpp b/paddle/gserver/layers/DeConv3DLayer.cpp index 8de40b681d72c..286f5b985c38e 100644 --- a/paddle/gserver/layers/DeConv3DLayer.cpp +++ b/paddle/gserver/layers/DeConv3DLayer.cpp @@ -12,43 +12,42 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "DeConv3DLayer.h" #include "paddle/utils/Logging.h" #include "paddle/utils/Stat.h" -#include "DeConv3DLayer.h" namespace paddle { REGISTER_LAYER(deconv3d, DeConv3DLayer); #define DECONV_OUTPUT_SIZE(IN_SIZE, STRID, PAD, KSIZE) \ - (((IN_SIZE) - 1) * (STRID) - 2 * (PAD) + (KSIZE)) + (((IN_SIZE)-1) * (STRID)-2 * (PAD) + (KSIZE)) bool DeConv3DLayer::init(const LayerMap &layerMap, - const ParameterMap ¶meterMap) { + const ParameterMap ¶meterMap) { if (!ConvBaseLayer::init(layerMap, parameterMap)) return false; // for Deconv, the dimension of Kernel is // channel * output * depth * height * weigth // Matrix storage format: (output * depth * height * weigth) x channel for (int index = 0; index < config_.inputs().size(); ++index) { M_.push_back(filterChannels_[index]); - K_.push_back( - filterPixels_[index] * (numFilters_/groups_[index])); - weights_[index]->getW()->reshape( - filterPixels_[index] * numFilters_, - filterChannels_[index]); - weights_[index]->getWGrad()->reshape( - filterPixels_[index] * numFilters_, - filterChannels_[index]); + K_.push_back(filterPixels_[index] * (numFilters_ / groups_[index])); + if (weights_[index]->getW()) + weights_[index]->getW()->reshape(filterPixels_[index] * numFilters_, + filterChannels_[index]); + if (weights_[index]->getWGrad()) + weights_[index]->getWGrad()->reshape(filterPixels_[index] * numFilters_, + filterChannels_[index]); } - biases_->getWGrad()->reshape( - biases_->getWGrad()->width_, biases_->getWGrad()->height_); - biases_->getW()->reshape( - biases_->getW()->width_, biases_->getW()->height_); + if (biases_->getWGrad()) + biases_->getWGrad()->reshape(biases_->getWGrad()->width_, + biases_->getWGrad()->height_); + if (biases_->getW()) + biases_->getW()->reshape(biases_->getW()->width_, biases_->getW()->height_); CHECK(inputLayers_.size() == parameters_.size()); return true; } - size_t DeConv3DLayer::getSize() { CHECK_NE(inputLayers_.size(), 0UL); // imgSizeH_.clear(); @@ -64,18 +63,12 @@ size_t DeConv3DLayer::getSize() { // imgSizeH_.push_back(inputLayers_[i]->getOutput().getFrameHeight()); // imgSizeW_.push_back(inputLayers_[i]->getOutput().getFrameWidth()); // imgSizeD_.push_back(inputLayers_[i]->getOutput().getFrameDepth()); - outputW_.push_back( - DECONV_OUTPUT_SIZE( - imgSizeW_[i], stride_[i], - padding_[i], filterSize_[i])); - outputH_.push_back( - DECONV_OUTPUT_SIZE( - imgSizeH_[i], strideY_[i], - paddingY_[i], filterSizeY_[i])); - outputD_.push_back( - DECONV_OUTPUT_SIZE( - imgSizeD_[i], strideZ_[i], - paddingZ_[i], filterSizeZ_[i])); + outputW_.push_back(DECONV_OUTPUT_SIZE( + imgSizeW_[i], stride_[i], padding_[i], filterSize_[i])); + outputH_.push_back(DECONV_OUTPUT_SIZE( + imgSizeH_[i], strideY_[i], paddingY_[i], filterSizeY_[i])); + outputD_.push_back(DECONV_OUTPUT_SIZE( + imgSizeD_[i], strideZ_[i], paddingZ_[i], filterSizeZ_[i])); No_.push_back(outputD_[i] * outputH_[i] * outputW_[i]); N_.push_back(imgSizeD_[i] * imgSizeH_[i] * imgSizeW_[i]); CHECK(layerSize == 0 || N_[i] * size_t(numFilters_) == layerSize); @@ -96,32 +89,37 @@ void DeConv3DLayer::forward(PassType passType) { for (size_t i = 0; i != inputLayers_.size(); ++i) { REGISTER_TIMER_INFO("FwdDeConv3D", getName().c_str()); - const MatrixPtr& inMat = getInputValue(i); - int width = inMat->getWidth(); + const MatrixPtr &inMat = getInputValue(i); int M = M_[i]; int N = N_[i]; int K = K_[i]; MatrixPtr wMat = weights_[i]->getW(); - Matrix::resizeOrCreate(colBuf_, K * groups_[i] , N, false, useGpu_); - + Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_); for (int n = 0; n < batchSize; ++n) { - real *inData = inMat->getData() + n * width; - real *colBufData = colBuf_->getData(); - for (int g = 0; g < groups_[i]; g++) { - MatrixPtr wMatSub = wMat->subMatrix(g * K, K); - MatrixPtr inMatSub = - Matrix::create(inData, M, N, false, useGpu_); - MatrixPtr colBufDataSub = - Matrix::create(colBufData, K, N, false, useGpu_); - colBufDataSub->mul(*wMatSub, *inMatSub, 1.0, 0.0); - colBufData += K * N; - inData += M * N; + real *inData = inMat->getData() + n * inMat->getStride(); + for (int g = 0; g < groups_[i]; ++g) { + MatrixPtr inMatSub = Matrix::create(inData, M, N, false, useGpu_); + MatrixPtr wMatSub = wMat->subMatrix(g * K, K); + MatrixPtr colBufDataSub = colBuf_->subMatrix(g * K, K); + colBufDataSub->mul(*wMatSub, *inMatSub, 1.0, 0.0); + inData += M * N; } - colBuf_->col2Vol(outMat->getData()+ n * outMat->getWidth(), - numFilters_, outputD_[i], outputH_[i], outputW_[i], - filterSizeZ_[i], filterSizeY_[i], filterSize_[i], - strideZ_[i], strideY_[i], stride_[i], - paddingZ_[i], paddingY_[i], padding_[i], 1.0, 1.0); + colBuf_->col2Vol(outMat->getData() + n * outMat->getStride(), + numFilters_, + outputD_[i], + outputH_[i], + outputW_[i], + filterSizeZ_[i], + filterSizeY_[i], + filterSize_[i], + strideZ_[i], + strideY_[i], + stride_[i], + paddingZ_[i], + paddingY_[i], + padding_[i], + 1.0, + 1.0); } } if (nullptr != this->biasParameter_) { @@ -134,63 +132,69 @@ void DeConv3DLayer::forward(PassType passType) { void DeConv3DLayer::backward(const UpdateCallback &callback) { backwardActivation(); int batchSize = getOutputGrad()->getHeight(); - int outputWidth = getOutputGrad()->getWidth(); if (biases_ && biases_->getWGrad()) { bpropBiases(); biases_->getParameterPtr()->incUpdate(callback); } - for (size_t i =0; i < inputLayers_.size(); ++i) { - int M = M_[i]; - int N = N_[i]; - int K = K_[i]; - Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_); - const MatrixPtr& inMat = getInputValue(i); - for (int n = 0; n < batchSize; ++n) { + for (size_t i = 0; i < inputLayers_.size(); ++i) { + if (weights_[i]->getWGrad() || this->needGradient_) { + int M = M_[i]; + int N = N_[i]; + int K = K_[i]; REGISTER_TIMER_INFO("BwdDeConv3D", getName().c_str()); - if (weights_[i]->getWGrad() || this->needGradient_) { - colBuf_->vol2Col(getOutputGrad()->getData() + n * outputWidth, - numFilters_, outputD_[i], outputH_[i], outputW_[i], - filterSizeZ_[i], filterSizeY_[i], filterSize_[i], - strideZ_[i], strideY_[i], stride_[i], - paddingZ_[i], paddingY_[i], padding_[i]); - } - if (weights_[i]->getWGrad()) { - real *inData = inMat->getData() + n * inMat->getWidth();; - real *wGradData = weights_[i]->getWGrad()->getData(); - for (int g = 0; g < groups_[i]; g++) { - MatrixPtr colBufDataSub = colBuf_->subMatrix(g * K, K); - MatrixPtr inMatSub = Matrix::create( - inData, M, N, false, useGpu_); - MatrixPtr wGradMatSub = Matrix::create( - wGradData, K, M, false, useGpu_); - wGradMatSub->mul(*colBufDataSub, - *(inMatSub->getTranspose()), 1.0, 1.0); - wGradData += K * M; - inData += M * N; + Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_); + const MatrixPtr &inMat = getInputValue(i); + for (int n = 0; n < batchSize; ++n) { + colBuf_->vol2Col( + getOutputGrad()->getData() + n * getOutputGrad()->getStride(), + numFilters_, + outputD_[i], + outputH_[i], + outputW_[i], + filterSizeZ_[i], + filterSizeY_[i], + filterSize_[i], + strideZ_[i], + strideY_[i], + stride_[i], + paddingZ_[i], + paddingY_[i], + padding_[i]); + if (weights_[i]->getWGrad()) { + real *inData = inMat->getData() + n * inMat->getStride(); + for (int g = 0; g < groups_[i]; ++g) { + MatrixPtr colBufDataSub = colBuf_->subMatrix(g * K, K); + MatrixPtr wGradMatSub = + weights_[i]->getWGrad()->subMatrix(g * K, K); + MatrixPtr inMatSub = Matrix::create(inData, M, N, false, useGpu_); + wGradMatSub->mul( + *colBufDataSub, *(inMatSub->getTranspose()), 1.0, 1.0); + inData += M * N; + } } - weights_[i]->getParameterPtr()->incUpdate(callback); - } - if (this->needGradient_) { - real* preGrad = getInputGrad(i)->getData(); - for (int g = 0; g < groups_[i]; ++g) { - MatrixPtr w = weights_[i]->getW()->subMatrix(g * K, K); - MatrixPtr outGradMat = colBuf_->subMatrix(g * K, K); - MatrixPtr inGradMatSub = Matrix::create( - preGrad, M, N, false, useGpu_); - inGradMatSub->mul(*(w->getTranspose()), *outGradMat, 1.0, 0.0); - preGrad += M * N; + if (getInputGrad(i)) { + real *preGrad = + getInputGrad(i)->getData() + n * getInputGrad(i)->getStride(); + for (int g = 0; g < groups_[i]; ++g) { + MatrixPtr w = weights_[i]->getW()->subMatrix(g * K, K); + MatrixPtr outGradMat = colBuf_->subMatrix(g * K, K); + MatrixPtr inGradMatSub = + Matrix::create(preGrad, M, N, false, useGpu_); + inGradMatSub->mul(*(w->getTranspose()), *outGradMat, 1.0, 1.0); + preGrad += M * N; + } } } REGISTER_TIMER_INFO("WeightUpdate", getName().c_str()); + weights_[i]->getParameterPtr()->incUpdate(callback); } } } - -void DeConv3DLayer::bpropWeights(int i) { } -void DeConv3DLayer::bpropData(int i) { } +void DeConv3DLayer::bpropWeights(int i) {} +void DeConv3DLayer::bpropData(int i) {} void DeConv3DLayer::bpropBiases() { - MatrixPtr outGradMat = getOutputGrad(); + const MatrixPtr &outGradMat = getOutputGrad(); if (this->sharedBiases_) { biases_->getWGrad()->collectSharedBias(*outGradMat, 1.0f); From 43f6cdc8247042244f9b75bac51957c962a16ffd Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Fri, 18 Aug 2017 14:13:25 +0800 Subject: [PATCH 10/20] fix Matrix --- paddle/math/Matrix.cpp | 110 +++++++++++++++++++------------- paddle/math/Matrix.h | 140 ++++++++++++++++++++++++++--------------- 2 files changed, 153 insertions(+), 97 deletions(-) diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index 66868e73b33a2..579a0f3cf32b8 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -1389,51 +1389,71 @@ void GpuMatrix::multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label) { output_d, grad_d, mat_d, height_, width_); } -void GpuMatrix::vol2Col(real* data, - int channels, - int depth, - int height, - int width, - int filterD, - int filterH, - int filterW, - int strideD, - int strideH, - int strideW, - int paddingD, - int paddingH, - int paddingW) { - hl_matrix_vol2Col(data, - channels, depth, height, width, - filterD, filterH, filterW, - strideD, strideH, strideW, - paddingD, paddingH, paddingW, getData()); -} - -void GpuMatrix::col2Vol(real* trg, - int channels, - int depth, - int height, - int width, - int filterD, - int filterH, - int filterW, - int strideD, - int strideH, - int strideW, - int paddingD, - int paddingH, - int paddingW, - real alpha, - real beta) { - hl_matrix_col2Vol(trg, - channels, depth, height, width, - filterD, filterH, filterW, - strideD, strideH, strideW, - paddingD, paddingH, paddingW, +void GpuMatrix::vol2Col(real* dataSrc, + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW) { + hl_matrix_vol2Col(dataSrc, + channels, + depth, + height, + width, + filterD, + filterH, + filterW, + strideD, + strideH, + strideW, + paddingD, + paddingH, + paddingW, + getData()); +} + +void GpuMatrix::col2Vol(real* dataDst, + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW, + real alpha, + real beta) { + hl_matrix_col2Vol(dataDst, + channels, + depth, + height, + width, + filterD, + filterH, + filterW, + strideD, + strideH, + strideW, + paddingD, + paddingH, + paddingW, getData(), - alpha, beta); - } + alpha, + beta); +} /** * CpuMatrix @@ -4082,7 +4102,7 @@ void CpuMatrix::col2Vol(real* trg, real alpha, real beta) { real* src = getData(); - int outDepth = (depth + 2 * paddingH - filterD) / strideD + 1; + int outDepth = (depth + 2 * paddingD - filterD) / strideD + 1; int outHeight = (height + 2 * paddingH - filterH) / strideH + 1; int outWidth = (width + 2 * paddingW - filterW) / strideW + 1; int channelsCol = channels * filterD * filterH * filterW; diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h index 4354996ce0b02..cc3a56f279cc6 100644 --- a/paddle/math/Matrix.h +++ b/paddle/math/Matrix.h @@ -1040,40 +1040,40 @@ class Matrix : public BaseMatrix { } virtual void vol2Col(real* data, - int channels, - int depth, - int height, - int width, - int filterD, - int filterH, - int filterW, - int strideD, - int strideH, - int strideW, - int paddingD, - int paddingH, - int paddingW) { - LOG(FATAL) << "Not implemeted"; - } + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW) { + LOG(FATAL) << "Not implemeted"; + } - virtual void col2Vol(real* trg, - int channels, - int depth, - int height, - int width, - int filterD, - int filterH, - int filterW, - int strideD, - int strideH, - int strideW, - int paddingD, - int paddingH, - int paddingW, - real alpha, - real beta) { - LOG(FATAL) << "Not implemeted"; - } + virtual void col2Vol(real* trg, + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW, + real alpha, + real beta) { + LOG(FATAL) << "Not implemeted"; + } virtual void bilinearForward(const Matrix& in, const size_t inImgH, @@ -1411,18 +1411,36 @@ class GpuMatrix : public Matrix { const real ratioW); void vol2Col(real* data, - int channels, - int depth, int height, int width, - int filterD, int filterH, int filterW, - int strideD, int strideH, int strideW, - int paddingD, int paddingH, int paddingW); + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW); void col2Vol(real* trg, - int channels, int depth, int height, int width, - int filterD, int filterH, int filterW, - int strideD, int strideH, int strideW, - int paddingD, int paddingH, int paddingW, - real alpha, real beta); + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW, + real alpha, + real beta); void multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label); @@ -1767,17 +1785,35 @@ class CpuMatrix : public Matrix { void vol2Col(real* data, int channels, - int depth, int height, int width, - int filterD, int filterH, int filterW, - int strideD, int strideH, int strideW, - int paddingD, int paddingH, int paddingW); + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW); void col2Vol(real* trg, - int channels, int depth, int height, int width, - int filterD, int filterH, int filterW, - int strideD, int strideH, int strideW, - int paddingD, int paddingH, int paddingW, - real alpha, real beta); + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW, + real alpha, + real beta); template void operator=(const ExpressionType& expr) { From 0a7516d193061ccb35ab410fc947bd245a936159 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Fri, 18 Aug 2017 14:14:27 +0800 Subject: [PATCH 11/20] fix col2vol vol2col kernel --- paddle/cuda/src/hl_cuda_matrix.cu | 192 ++++++++++++++++++++---------- 1 file changed, 129 insertions(+), 63 deletions(-) diff --git a/paddle/cuda/src/hl_cuda_matrix.cu b/paddle/cuda/src/hl_cuda_matrix.cu index f626c07a0c39a..3bf1b0251f3f0 100644 --- a/paddle/cuda/src/hl_cuda_matrix.cu +++ b/paddle/cuda/src/hl_cuda_matrix.cu @@ -593,21 +593,28 @@ void hl_matrix_rotate( CHECK_SYNC("hl_matrix_rotate failed"); } - -__global__ void keMatrixVol2Col( - int num_kernels, real*dataSrc, real* dataDst, - int depth, int height, int width, - int filterD, int filterH, int filterW, - int strideD, int strideH, int strideW, - int paddingD, int paddingH, int paddingW, - int depth_col, int height_col, int width_col){ - - for (int index = blockIdx.x * blockDim.x + threadIdx.x; - index < num_kernels; - index += blockDim.x * gridDim.x){ - +__global__ void keMatrixVol2Col(int num_kernels, + real* dataSrc, + real* dataDst, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW, + int depth_col, + int height_col, + int width_col) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels; + index += blockDim.x * gridDim.x) { int w_out = index % width_col; - int h_out = (index / width_col ) % height_col; + int h_out = (index / width_col) % height_col; int d_out = (index / width_col / height_col) % depth_col; int channel_in = index / width_col / height_col / depth_col; int channel_out = channel_in * filterD * filterH * filterW; @@ -615,7 +622,9 @@ __global__ void keMatrixVol2Col( int h_in = h_out * strideH - paddingH; int d_in = d_out * strideD - paddingD; - dataDst += ((channel_out * depth_col + d_out) * height_col + h_out) * width_col + w_out; + dataDst += + ((channel_out * depth_col + d_out) * height_col + h_out) * width_col + + w_out; dataSrc += ((channel_in * depth + d_in) * height + h_in) * width + w_in; for (int k = 0; k < filterD; ++k) { for (int i = 0; i < filterH; ++i) { @@ -623,8 +632,10 @@ __global__ void keMatrixVol2Col( int d = d_in + k; int h = h_in + i; int w = w_in + j; - *dataDst = (d >= 0 && d < depth && h >= 0 && h < height && w >= 0 && w < width ) ? - dataSrc[(k * height + i) * width + j] : 0; + *dataDst = (d >= 0 && d < depth && h >= 0 && h < height && w >= 0 && + w < width) + ? dataSrc[(k * height + i) * width + j] + : 0; dataDst += depth_col * height_col * width_col; } } @@ -633,11 +644,20 @@ __global__ void keMatrixVol2Col( } void hl_matrix_vol2Col(real* dataSrc, - int channels, int depth, int height, int width, - int filterD, int filterH, int filterW, - int strideD, int strideH, int strideW, - int paddingD, int paddingH, int paddingW, real* dataDst){ - + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW, + real* dataDst) { int depth_col = (depth + 2 * paddingD - filterD) / strideD + 1; int height_col = (height + 2 * paddingH - filterH) / strideH + 1; int width_col = (width + 2 * paddingW - filterW) / strideW + 1; @@ -646,34 +666,55 @@ void hl_matrix_vol2Col(real* dataSrc, const int threads = 512; const int blocks = DIVUP(num_kernels, threads); - keMatrixVol2Col<<< blocks, threads >>>( - num_kernels, dataSrc, dataDst, - depth, height, width, - filterD, filterH, filterW, - strideD, strideH, strideW, - paddingD, paddingH, paddingW, - depth_col, height_col, width_col); + keMatrixVol2Col<<>>(num_kernels, + dataSrc, + dataDst, + depth, + height, + width, + filterD, + filterH, + filterW, + strideD, + strideH, + strideW, + paddingD, + paddingH, + paddingW, + depth_col, + height_col, + width_col); CHECK_SYNC("hl_matrix_vol2Col failed"); } -__global__ void keMatrixCol2Vol( - int num_kernels, real*dataDst, real* dataSrc, - int depth, int height, int width, - int filterD, int filterH, int filterW, - int strideD, int strideH, int strideW, - int paddingD, int paddingH, int paddingW, - int depth_col, int height_col, int width_col, - real alpha, real beta){ - - for (int index = blockIdx.x * blockDim.x + threadIdx.x; - index < num_kernels; +__global__ void keMatrixCol2Vol(int num_kernels, + real* dataDst, + real* dataSrc, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW, + int depth_col, + int height_col, + int width_col, + real alpha, + real beta) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels; index += blockDim.x * gridDim.x) { - - real val = 0; + real srcVal = 0; + real dstVal = dataDst[index]; int w = index % width + paddingW; int h = (index / width) % height + paddingH; int d = (index / width / height) % depth + paddingD; - int c = index / (width * height * depth); + int c = index / width / height / depth; // compute the start and end of the output int w_col_start = (w < filterW) ? 0 : (w - filterW) / strideW + 1; int w_col_end = min(w / strideW + 1, width_col); @@ -682,32 +723,45 @@ __global__ void keMatrixCol2Vol( int d_col_start = (d < filterD) ? 0 : (d - filterD) / strideD + 1; int d_col_end = min(d / strideD + 1, depth_col); - int offset = (c * filterD * filterW * filterH + \ - d * filterW * filterH + h * filterW + w) * depth_col * height_col * width_col; + int offset = (c * filterD * filterW * filterH + d * filterW * filterH + + h * filterW + w) * + depth_col * height_col * width_col; - int coeff_d_col = (1 - strideD * filterW * filterH * depth_col) * height_col * width_col; - int coeff_h_col = (1 - strideH * filterW * depth_col * height_col) * width_col; + int coeff_d_col = + (1 - strideD * filterW * filterH * depth_col) * height_col * width_col; + int coeff_h_col = + (1 - strideH * filterW * depth_col * height_col) * width_col; int coeff_w_col = (1 - strideW * depth_col * height_col * width_col); for (int d_col = d_col_start; d_col < d_col_end; ++d_col) { for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { - val += dataSrc[offset + d_col * coeff_d_col + h_col * coeff_h_col + w_col * coeff_w_col]; + srcVal += dataSrc[offset + d_col * coeff_d_col + h_col * coeff_h_col + + w_col * coeff_w_col]; } } } - dataDst[index] = val; + dataDst[index] = alpha * srcVal + beta * dstVal; } } void hl_matrix_col2Vol(real* dataDst, - int channels, int depth, int height, int width, - int filterD, int filterH, int filterW, - int strideD, int strideH, int strideW, - int paddingD, int paddingH, int paddingW, + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW, real* dataSrc, - real alpha, real beta){ - + real alpha, + real beta) { int depth_col = (depth + 2 * paddingD - filterD) / strideD + 1; int height_col = (height + 2 * paddingH - filterH) / strideH + 1; int width_col = (width + 2 * paddingW - filterW) / strideW + 1; @@ -716,14 +770,26 @@ void hl_matrix_col2Vol(real* dataDst, const int threads = 512; const int blocks = DIVUP(num_kernels, threads); - keMatrixCol2Vol<<< blocks, threads >>>( - num_kernels, dataDst, dataSrc, - depth, height, width, - filterD, filterH, filterW, - strideD, strideH, strideW, - paddingD, paddingH, paddingW, - depth_col, height_col, width_col, - alpha, beta); + keMatrixCol2Vol<<>>(num_kernels, + dataDst, + dataSrc, + depth, + height, + width, + filterD, + filterH, + filterW, + strideD, + strideH, + strideW, + paddingD, + paddingH, + paddingW, + depth_col, + height_col, + width_col, + alpha, + beta); CHECK_SYNC("hl_matrix_col2Vol failed"); } From 38cc5dadcc5c76c4aa50f5e92b560f4ccaba9227 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Fri, 18 Aug 2017 16:43:59 +0800 Subject: [PATCH 12/20] modified bias shape of ConvLayer --- paddle/gserver/layers/Conv3DLayer.cpp | 5 ----- paddle/gserver/layers/ConvBaseLayer.cpp | 17 ++++++++--------- paddle/gserver/layers/DeConv3DLayer.cpp | 5 ----- 3 files changed, 8 insertions(+), 19 deletions(-) diff --git a/paddle/gserver/layers/Conv3DLayer.cpp b/paddle/gserver/layers/Conv3DLayer.cpp index 5609a4cc73562..106909824dfff 100644 --- a/paddle/gserver/layers/Conv3DLayer.cpp +++ b/paddle/gserver/layers/Conv3DLayer.cpp @@ -37,11 +37,6 @@ bool Conv3DLayer::init(const LayerMap &layerMap, weights_[index]->getWGrad()->getHeight()); ++index; } - if (nullptr != biases_->getWGrad()) - biases_->getWGrad()->reshape(biases_->getWGrad()->width_, - biases_->getWGrad()->height_); - if (nullptr != biases_->getW()) - biases_->getW()->reshape(biases_->getW()->width_, biases_->getW()->height_); CHECK(inputLayers_.size() == parameters_.size()); return true; } diff --git a/paddle/gserver/layers/ConvBaseLayer.cpp b/paddle/gserver/layers/ConvBaseLayer.cpp index e437b0b86e0de..6bcbe0ddb2d3a 100644 --- a/paddle/gserver/layers/ConvBaseLayer.cpp +++ b/paddle/gserver/layers/ConvBaseLayer.cpp @@ -21,11 +21,10 @@ bool ConvBaseLayer::init(const LayerMap& layerMap, const ParameterMap& parameterMap) { /* Initialize the basic parent class */ Layer::init(layerMap, parameterMap); - isDeconv_ = (config_.type() == "exconv" || - config_.type() == "cudnn_conv" || - config_.type() == "conv3d" || - config_.type() == "deconv3d" ) - ? false : true; + isDeconv_ = (config_.type() == "exconv" || config_.type() == "cudnn_conv" || + config_.type() == "conv3d" || config_.type() == "deconv3d") + ? false + : true; /* Initialize the convolutional layer parameter */ numFilters_ = config_.num_filters(); @@ -52,8 +51,8 @@ bool ConvBaseLayer::init(const LayerMap& layerMap, filterSizeZ_.push_back(conf.filter_size_z()); imgSizeD_.push_back(conf.img_size_z()); outputD_.push_back(conf.output_z()); - filterPixels_.push_back( - filterSize_.back() * filterSizeY_.back() * filterSizeZ_.back()); + filterPixels_.push_back(filterSize_.back() * filterSizeY_.back() * + filterSizeZ_.back()); } CHECK(inputLayers_.size() == parameters_.size()); @@ -73,10 +72,10 @@ bool ConvBaseLayer::init(const LayerMap& layerMap, if (sharedBiases_) { CHECK_EQ((size_t)numFilters_, biasParameter_->getSize()); biases_ = - std::unique_ptr(new Weight(numFilters_, 1, biasParameter_)); + std::unique_ptr(new Weight(1, numFilters_, biasParameter_)); } else { biases_ = - std::unique_ptr(new Weight(getSize(), 1, biasParameter_)); + std::unique_ptr(new Weight(1, getSize(), biasParameter_)); } } diff --git a/paddle/gserver/layers/DeConv3DLayer.cpp b/paddle/gserver/layers/DeConv3DLayer.cpp index 286f5b985c38e..5a54a684471cb 100644 --- a/paddle/gserver/layers/DeConv3DLayer.cpp +++ b/paddle/gserver/layers/DeConv3DLayer.cpp @@ -39,11 +39,6 @@ bool DeConv3DLayer::init(const LayerMap &layerMap, weights_[index]->getWGrad()->reshape(filterPixels_[index] * numFilters_, filterChannels_[index]); } - if (biases_->getWGrad()) - biases_->getWGrad()->reshape(biases_->getWGrad()->width_, - biases_->getWGrad()->height_); - if (biases_->getW()) - biases_->getW()->reshape(biases_->getW()->width_, biases_->getW()->height_); CHECK(inputLayers_.size() == parameters_.size()); return true; } From d5768ebc89868431040e47e3db126263da385d70 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Fri, 18 Aug 2017 20:49:35 +0800 Subject: [PATCH 13/20] fix above comments --- paddle/cuda/include/hl_matrix.h | 58 ++++++++----- paddle/cuda/include/stub/hl_matrix_stub.h | 47 +++++++---- paddle/cuda/src/hl_cuda_matrix.cu | 84 +++++++++---------- paddle/gserver/layers/Conv3DLayer.cpp | 26 ++++-- paddle/gserver/layers/Conv3DLayer.h | 14 +--- paddle/gserver/layers/ConvBaseLayer.cpp | 26 +----- paddle/gserver/layers/ConvBaseLayer.h | 1 - paddle/gserver/layers/CudnnConvBaseLayer.cpp | 18 ++++ paddle/gserver/layers/DeConv3DLayer.cpp | 46 +++++----- paddle/gserver/layers/DeConv3DLayer.h | 44 +++++----- paddle/gserver/layers/ExpandConvBaseLayer.cpp | 21 ++++- paddle/gserver/tests/test_LayerGrad.cpp | 31 +++---- paddle/math/tests/test_matrixCompare.cpp | 28 ++----- proto/ModelConfig.proto | 4 +- 14 files changed, 247 insertions(+), 201 deletions(-) diff --git a/paddle/cuda/include/hl_matrix.h b/paddle/cuda/include/hl_matrix.h index da2ed8cabb766..a37921b7493e3 100644 --- a/paddle/cuda/include/hl_matrix.h +++ b/paddle/cuda/include/hl_matrix.h @@ -240,16 +240,25 @@ extern void hl_matrix_rotate( * @param[in] strideW stride in the width. * @param[in] paddingD padding in the depth. * @param[in] paddingH padding in the height. - * @param[in] paddingW padding in the width. + * @param[in] paddingW padding in the width. * @param[out] matDst output matrix. - * + * */ -extern void hl_matrix_vol2Col(real* matSrc, - int channel, int depth, int height, int width, - int filterD, int filterH, int filterW, - int strideD, int strideH, int strideW, - int paddingD, int paddingH, int paddingW, - real* matDst); +extern void hl_matrix_vol2Col(const real* dataSrc, + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW, + real* dataDst); /** * @brief Matrix col2Vol: Convert col matrix into 3D volume @@ -267,19 +276,28 @@ extern void hl_matrix_vol2Col(real* matSrc, * @param[in] strideW stride in the width. * @param[in] paddingD padding in the depth. * @param[in] paddingH padding in the height. - * @param[in] paddingW padding in the width. + * @param[in] paddingW padding in the width. * @param[in] matSrc input matrix. - * @param[in] beta input - * @param[in] alpha input - * + * @param[in] beta input + * @param[in] alpha input + * */ -extern void hl_matrix_col2Vol(real* matDst, - int channels, int depth, int height, int width, - int filterD, int filterH, int filterW, - int strideD, int strideH, int strideW, - int paddingD, int paddingH, int paddingW, - real* matSrc, - real alpha, real beta); - +extern void hl_matrix_col2Vol(real* dataDst, + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW, + const real* dataSrc, + real alpha, + real beta); #endif /* HL_MATRIX_H_ */ diff --git a/paddle/cuda/include/stub/hl_matrix_stub.h b/paddle/cuda/include/stub/hl_matrix_stub.h index 0b73777812ae6..6ac332945c8f0 100644 --- a/paddle/cuda/include/stub/hl_matrix_stub.h +++ b/paddle/cuda/include/stub/hl_matrix_stub.h @@ -99,19 +99,38 @@ inline void hl_matrix_collect_shared_bias(real* B_d, inline void hl_matrix_rotate( real* mat, real* matRot, int dimM, int dimN, bool clockWise) {} -inline void hl_matrix_vol2Col(real* data, - int channels, int depth, int height, int width, - int filterD, int filterH, int filterW, - int strideD, int strideH, int strideW, - int paddingD, int paddingH, int paddingW, - real* data_col) {} - -inline void hl_matrix_col2Vol(real* data, - int channels, int depth, int height, int width, - int filterD, int filterH, int filterW, - int strideD, int strideH, int strideW, - int paddingD, int paddingH, int paddingW, - real* data_Im, - real alpha, real beta) {} +inline void hl_matrix_vol2Col(const real* dataSrc, + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW, + real* dataDst) {} + +inline void hl_matrix_col2Vol(real* dataDst, + int channels, + int depth, + int height, + int width, + int filterD, + int filterH, + int filterW, + int strideD, + int strideH, + int strideW, + int paddingD, + int paddingH, + int paddingW, + const real* dataSrc, + real alpha, + real beta) {} #endif // HL_MATRIX_STUB_H_ diff --git a/paddle/cuda/src/hl_cuda_matrix.cu b/paddle/cuda/src/hl_cuda_matrix.cu index 3bf1b0251f3f0..b41a3a1e06db7 100644 --- a/paddle/cuda/src/hl_cuda_matrix.cu +++ b/paddle/cuda/src/hl_cuda_matrix.cu @@ -594,7 +594,7 @@ void hl_matrix_rotate( } __global__ void keMatrixVol2Col(int num_kernels, - real* dataSrc, + const real* dataSrc, real* dataDst, int depth, int height, @@ -643,7 +643,7 @@ __global__ void keMatrixVol2Col(int num_kernels, } } -void hl_matrix_vol2Col(real* dataSrc, +void hl_matrix_vol2Col(const real* dataSrc, int channels, int depth, int height, @@ -666,30 +666,30 @@ void hl_matrix_vol2Col(real* dataSrc, const int threads = 512; const int blocks = DIVUP(num_kernels, threads); - keMatrixVol2Col<<>>(num_kernels, - dataSrc, - dataDst, - depth, - height, - width, - filterD, - filterH, - filterW, - strideD, - strideH, - strideW, - paddingD, - paddingH, - paddingW, - depth_col, - height_col, - width_col); + keMatrixVol2Col<<>>(num_kernels, + dataSrc, + dataDst, + depth, + height, + width, + filterD, + filterH, + filterW, + strideD, + strideH, + strideW, + paddingD, + paddingH, + paddingW, + depth_col, + height_col, + width_col); CHECK_SYNC("hl_matrix_vol2Col failed"); } __global__ void keMatrixCol2Vol(int num_kernels, real* dataDst, - real* dataSrc, + const real* dataSrc, int depth, int height, int width, @@ -759,7 +759,7 @@ void hl_matrix_col2Vol(real* dataDst, int paddingD, int paddingH, int paddingW, - real* dataSrc, + const real* dataSrc, real alpha, real beta) { int depth_col = (depth + 2 * paddingD - filterD) / strideD + 1; @@ -770,26 +770,26 @@ void hl_matrix_col2Vol(real* dataDst, const int threads = 512; const int blocks = DIVUP(num_kernels, threads); - keMatrixCol2Vol<<>>(num_kernels, - dataDst, - dataSrc, - depth, - height, - width, - filterD, - filterH, - filterW, - strideD, - strideH, - strideW, - paddingD, - paddingH, - paddingW, - depth_col, - height_col, - width_col, - alpha, - beta); + keMatrixCol2Vol<<>>(num_kernels, + dataDst, + dataSrc, + depth, + height, + width, + filterD, + filterH, + filterW, + strideD, + strideH, + strideW, + paddingD, + paddingH, + paddingW, + depth_col, + height_col, + width_col, + alpha, + beta); CHECK_SYNC("hl_matrix_col2Vol failed"); } diff --git a/paddle/gserver/layers/Conv3DLayer.cpp b/paddle/gserver/layers/Conv3DLayer.cpp index 106909824dfff..db907bbab1c28 100644 --- a/paddle/gserver/layers/Conv3DLayer.cpp +++ b/paddle/gserver/layers/Conv3DLayer.cpp @@ -28,16 +28,26 @@ bool Conv3DLayer::init(const LayerMap &layerMap, const ConvConfig &conf = inputConfig.conv_conf(); M_.push_back(numFilters_ / conf.groups()); K_.push_back(filterPixels_[index] * filterChannels_[index]); - if (nullptr != weights_[index]->getW()) - weights_[index]->getW()->reshape(weights_[index]->getW()->getWidth(), - weights_[index]->getW()->getHeight()); - if (nullptr != weights_[index]->getWGrad()) - weights_[index]->getWGrad()->reshape( - weights_[index]->getWGrad()->getWidth(), - weights_[index]->getWGrad()->getHeight()); + + // create a new weight + size_t height, width; + width = filterPixels_[index] * filterChannels_[index]; + height = numFilters_; + CHECK_EQ(parameters_[index]->getSize(), width * height); + Weight *w = new Weight(height, width, parameters_[index]); + weights_.emplace_back(w); ++index; } - CHECK(inputLayers_.size() == parameters_.size()); + if (biasParameter_.get()) { + if (sharedBiases_) { + CHECK_EQ((size_t)numFilters_, biasParameter_->getSize()); + biases_ = + std::unique_ptr(new Weight(1, numFilters_, biasParameter_)); + } else { + biases_ = + std::unique_ptr(new Weight(1, getSize(), biasParameter_)); + } + } return true; } diff --git a/paddle/gserver/layers/Conv3DLayer.h b/paddle/gserver/layers/Conv3DLayer.h index 703671e5d0dcb..b622508d0ce1b 100644 --- a/paddle/gserver/layers/Conv3DLayer.h +++ b/paddle/gserver/layers/Conv3DLayer.h @@ -12,13 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - #pragma once - +#include #include "ConvBaseLayer.h" -#include "paddle/math/Matrix.h" #include "paddle/math/MathUtils.h" -#include +#include "paddle/math/Matrix.h" namespace paddle { @@ -30,21 +28,17 @@ namespace paddle { class Conv3DLayer : public ConvBaseLayer { public: explicit Conv3DLayer(const LayerConfig& config) : ConvBaseLayer(config) {} - ~Conv3DLayer() {} - bool init(const LayerMap &layerMap, const ParameterMap ¶meterMap); - - size_t getSize(); + bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); void forward(PassType passType); void addBias(); - void backward(const UpdateCallback& callback); - void bpropBiases(); void bpropData(int i); void bpropWeights(int i); + size_t getSize(); protected: // Figure out the dimensions for individual gemms. diff --git a/paddle/gserver/layers/ConvBaseLayer.cpp b/paddle/gserver/layers/ConvBaseLayer.cpp index 6bcbe0ddb2d3a..8c637eaec93d5 100644 --- a/paddle/gserver/layers/ConvBaseLayer.cpp +++ b/paddle/gserver/layers/ConvBaseLayer.cpp @@ -21,8 +21,7 @@ bool ConvBaseLayer::init(const LayerMap& layerMap, const ParameterMap& parameterMap) { /* Initialize the basic parent class */ Layer::init(layerMap, parameterMap); - isDeconv_ = (config_.type() == "exconv" || config_.type() == "cudnn_conv" || - config_.type() == "conv3d" || config_.type() == "deconv3d") + isDeconv_ = (config_.type() == "exconv" || config_.type() == "cudnn_conv") ? false : true; @@ -56,28 +55,9 @@ bool ConvBaseLayer::init(const LayerMap& layerMap, } CHECK(inputLayers_.size() == parameters_.size()); - for (size_t i = 0; i < inputLayers_.size(); i++) { - size_t height, width; - height = filterPixels_[i] * filterChannels_[i]; - width = (!isDeconv_) ? numFilters_ : channels_[i]; - - // create a new weight - CHECK_EQ(parameters_[i]->getSize(), width * height); - Weight* w = new Weight(height, width, parameters_[i]); - weights_.emplace_back(w); - } - /* initialize the biases_ */ - if (biasParameter_.get()) { - if (sharedBiases_) { - CHECK_EQ((size_t)numFilters_, biasParameter_->getSize()); - biases_ = - std::unique_ptr(new Weight(1, numFilters_, biasParameter_)); - } else { - biases_ = - std::unique_ptr(new Weight(1, getSize(), biasParameter_)); - } - } + // create new weights_ in derived class + // create new biases_ in derived class // default caffe model caffeMode_ = true; diff --git a/paddle/gserver/layers/ConvBaseLayer.h b/paddle/gserver/layers/ConvBaseLayer.h index 8d1fd989e8383..629c462776d74 100644 --- a/paddle/gserver/layers/ConvBaseLayer.h +++ b/paddle/gserver/layers/ConvBaseLayer.h @@ -23,7 +23,6 @@ namespace paddle { * with learned filters and (optionally) adds biases. */ - class ConvBaseLayer : public Layer { protected: typedef std::vector IntV; diff --git a/paddle/gserver/layers/CudnnConvBaseLayer.cpp b/paddle/gserver/layers/CudnnConvBaseLayer.cpp index c056bbe4d1d35..9e954615cddf2 100644 --- a/paddle/gserver/layers/CudnnConvBaseLayer.cpp +++ b/paddle/gserver/layers/CudnnConvBaseLayer.cpp @@ -46,8 +46,26 @@ bool CudnnConvBaseLayer::init(const LayerMap &layerMap, projConf_.emplace_back(conf); projections_.emplace_back( Projection::create(*projConf_[i], parameters_[i], useGpu_)); + + // create a new weight + size_t height, width; + height = filterPixels_[i] * filterChannels_[i]; + width = (!isDeconv_) ? numFilters_ : channels_[i]; + CHECK_EQ(parameters_[i]->getSize(), width * height); + Weight *w = new Weight(height, width, parameters_[i]); + weights_.emplace_back(w); } + if (biasParameter_.get()) { + if (sharedBiases_) { + CHECK_EQ((size_t)numFilters_, biasParameter_->getSize()); + biases_ = + std::unique_ptr(new Weight(numFilters_, 1, biasParameter_)); + } else { + biases_ = + std::unique_ptr(new Weight(getSize(), 1, biasParameter_)); + } + } if (biases_.get() && sharedBiases_) { hl_create_tensor_descriptor(&biasDesc_); hl_create_tensor_descriptor(&outputDesc_); diff --git a/paddle/gserver/layers/DeConv3DLayer.cpp b/paddle/gserver/layers/DeConv3DLayer.cpp index 5a54a684471cb..b18c06e36c897 100644 --- a/paddle/gserver/layers/DeConv3DLayer.cpp +++ b/paddle/gserver/layers/DeConv3DLayer.cpp @@ -20,9 +20,6 @@ namespace paddle { REGISTER_LAYER(deconv3d, DeConv3DLayer); -#define DECONV_OUTPUT_SIZE(IN_SIZE, STRID, PAD, KSIZE) \ - (((IN_SIZE)-1) * (STRID)-2 * (PAD) + (KSIZE)) - bool DeConv3DLayer::init(const LayerMap &layerMap, const ParameterMap ¶meterMap) { if (!ConvBaseLayer::init(layerMap, parameterMap)) return false; @@ -32,14 +29,25 @@ bool DeConv3DLayer::init(const LayerMap &layerMap, for (int index = 0; index < config_.inputs().size(); ++index) { M_.push_back(filterChannels_[index]); K_.push_back(filterPixels_[index] * (numFilters_ / groups_[index])); - if (weights_[index]->getW()) - weights_[index]->getW()->reshape(filterPixels_[index] * numFilters_, - filterChannels_[index]); - if (weights_[index]->getWGrad()) - weights_[index]->getWGrad()->reshape(filterPixels_[index] * numFilters_, - filterChannels_[index]); + + // create a new weight + size_t height, width; + height = filterPixels_[index] * numFilters_; + width = filterChannels_[index]; + CHECK_EQ(parameters_[index]->getSize(), width * height); + Weight *w = new Weight(height, width, parameters_[index]); + weights_.emplace_back(w); + } + if (biasParameter_.get()) { + if (sharedBiases_) { + CHECK_EQ((size_t)numFilters_, biasParameter_->getSize()); + biases_ = + std::unique_ptr(new Weight(1, numFilters_, biasParameter_)); + } else { + biases_ = + std::unique_ptr(new Weight(1, getSize(), biasParameter_)); + } } - CHECK(inputLayers_.size() == parameters_.size()); return true; } @@ -52,22 +60,22 @@ size_t DeConv3DLayer::getSize() { outputW_.clear(); outputD_.clear(); N_.clear(); - No_.clear(); + NOut_.clear(); size_t layerSize = 0; for (size_t i = 0; i < inputLayers_.size(); ++i) { // imgSizeH_.push_back(inputLayers_[i]->getOutput().getFrameHeight()); // imgSizeW_.push_back(inputLayers_[i]->getOutput().getFrameWidth()); // imgSizeD_.push_back(inputLayers_[i]->getOutput().getFrameDepth()); - outputW_.push_back(DECONV_OUTPUT_SIZE( - imgSizeW_[i], stride_[i], padding_[i], filterSize_[i])); - outputH_.push_back(DECONV_OUTPUT_SIZE( - imgSizeH_[i], strideY_[i], paddingY_[i], filterSizeY_[i])); - outputD_.push_back(DECONV_OUTPUT_SIZE( - imgSizeD_[i], strideZ_[i], paddingZ_[i], filterSizeZ_[i])); - No_.push_back(outputD_[i] * outputH_[i] * outputW_[i]); + outputW_.push_back( + imageSize(imgSizeW_[i], filterSize_[i], padding_[i], stride_[i], true)); + outputH_.push_back(imageSize( + imgSizeH_[i], filterSizeY_[i], paddingY_[i], strideY_[i], true)); + outputD_.push_back(imageSize( + imgSizeD_[i], filterSizeZ_[i], paddingZ_[i], strideZ_[i], true)); + NOut_.push_back(outputD_[i] * outputH_[i] * outputW_[i]); N_.push_back(imgSizeD_[i] * imgSizeH_[i] * imgSizeW_[i]); CHECK(layerSize == 0 || N_[i] * size_t(numFilters_) == layerSize); - layerSize += No_[i] * numFilters_; + layerSize += NOut_[i] * numFilters_; } getOutput().setFrameHeight(outputH_[0]); getOutput().setFrameWidth(outputW_[0]); diff --git a/paddle/gserver/layers/DeConv3DLayer.h b/paddle/gserver/layers/DeConv3DLayer.h index 435807fe5defa..a2a3d3f8273ed 100644 --- a/paddle/gserver/layers/DeConv3DLayer.h +++ b/paddle/gserver/layers/DeConv3DLayer.h @@ -12,13 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - #pragma once +#include #include "ConvBaseLayer.h" -#include "paddle/math/Matrix.h" #include "paddle/math/MathUtils.h" -#include +#include "paddle/math/Matrix.h" namespace paddle { @@ -29,30 +28,25 @@ namespace paddle { */ class DeConv3DLayer : public ConvBaseLayer { public: - explicit DeConv3DLayer(const LayerConfig& config) : ConvBaseLayer(config) {} - - ~DeConv3DLayer() {} - - bool init(const LayerMap &layerMap, const ParameterMap ¶meterMap); - - size_t getSize(); - - void forward(PassType passType); - void addBias(); - - void backward(const UpdateCallback& callback); - - void bpropBiases(); - void bpropData(int i); - void bpropWeights(int i); + explicit DeConv3DLayer(const LayerConfig& config) : ConvBaseLayer(config) {} + ~DeConv3DLayer() {} + bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); + + void forward(PassType passType); + void addBias(); + void backward(const UpdateCallback& callback); + void bpropBiases(); + void bpropData(int i); + void bpropWeights(int i); + size_t getSize(); protected: - // Figure out the dimensions for individual gemms. - IntV M_; /// numFilters_ / filter_group_; - IntV N_; /// channels_ * filterSizeZ_ * filterSize_ * filterSizeY_ - IntV K_; /// outputD_ * outputH_ * outputW_ - IntV No_; - MatrixPtr colBuf_; + // Figure out the dimensions for individual gemms. + IntV M_; /// numFilters_ / filter_group_; + IntV N_; /// channels_ * filterSizeZ_ * filterSize_ * filterSizeY_ + IntV K_; /// outputD_ * outputH_ * outputW_ + IntV NOut_; + MatrixPtr colBuf_; }; } // namespace paddle diff --git a/paddle/gserver/layers/ExpandConvBaseLayer.cpp b/paddle/gserver/layers/ExpandConvBaseLayer.cpp index 77736e78f9349..2b7bef0a757d7 100644 --- a/paddle/gserver/layers/ExpandConvBaseLayer.cpp +++ b/paddle/gserver/layers/ExpandConvBaseLayer.cpp @@ -22,12 +22,31 @@ bool ExpandConvBaseLayer::init(const LayerMap &layerMap, /* Initialize the basic convolutional parent class */ ConvBaseLayer::init(layerMap, parameterMap); + int index = 0; for (auto &inputConfig : config_.inputs()) { const ConvConfig &conf = inputConfig.conv_conf(); /* Consistent caffe mode for multiple input */ caffeMode_ = conf.caffe_mode(); - } + // create a new weight + size_t height, width; + height = filterPixels_[index] * filterChannels_[index]; + width = (!isDeconv_) ? numFilters_ : channels_[index]; + CHECK_EQ(parameters_[index]->getSize(), width * height); + Weight *w = new Weight(height, width, parameters_[index]); + weights_.emplace_back(w); + index++; + } + if (biasParameter_.get()) { + if (sharedBiases_) { + CHECK_EQ((size_t)numFilters_, biasParameter_->getSize()); + biases_ = + std::unique_ptr(new Weight(numFilters_, 1, biasParameter_)); + } else { + biases_ = + std::unique_ptr(new Weight(getSize(), 1, biasParameter_)); + } + } getOutputSize(); return true; diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index 1e80e2c0ee0f8..d5724293bf828 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -2019,7 +2019,7 @@ void test3DConvLayer(const string& type, bool trans, bool useGpu) { const int CHANNELS = 3; const int IMAGE_SIZE = 9; const int IMAGE_SIZE_Y = 9; - const int IMAGE_SIZE_Z = 9; // 2, 3, 5, 5, 5 + const int IMAGE_SIZE_Z = 9; TestConfig config; config.biasSize = NUM_FILTERS; @@ -2084,10 +2084,6 @@ TEST(Layer, test3DConvLayer) { #endif } -int deConvOutputSize(int inSize, int kSize, int pad, int stride) { - return (inSize - 1) * stride - 2 * pad + kSize; -} - void test3DDeConvLayer(const string& type, bool trans, bool useGpu) { // filter size const int NUM_FILTERS = 6; @@ -2126,16 +2122,21 @@ void test3DDeConvLayer(const string& type, bool trans, bool useGpu) { conv->set_img_size(IMAGE_SIZE); conv->set_img_size_y(IMAGE_SIZE_Y); conv->set_img_size_z(IMAGE_SIZE_Z); - conv->set_output_x(deConvOutputSize( - conv->img_size(), conv->filter_size(), conv->padding(), conv->stride())); - conv->set_output_y(deConvOutputSize(conv->img_size_y(), - conv->filter_size_y(), - conv->padding_y(), - conv->stride_y())); - conv->set_output_z(deConvOutputSize(conv->img_size_z(), - conv->filter_size_z(), - conv->padding_z(), - conv->stride_z())); + conv->set_output_x(imageSize(conv->img_size(), + conv->filter_size(), + conv->padding(), + conv->stride(), + true)); + conv->set_output_y(imageSize(conv->img_size_y(), + conv->filter_size_y(), + conv->padding_y(), + conv->stride_y(), + true)); + conv->set_output_z(imageSize(conv->img_size_z(), + conv->filter_size_z(), + conv->padding_z(), + conv->stride_z(), + true)); config.layerConfig.set_size(conv->output_x() * conv->output_y() * conv->output_z() * NUM_FILTERS); conv->set_groups(1); diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 1d41ec087028a..3abe4484dbc86 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -18,6 +18,7 @@ limitations under the License. */ #include #include "TensorCheck.h" +#include "paddle/math/MathUtils.h" #include "paddle/math/Matrix.h" #include "paddle/math/SparseMatrix.h" #include "paddle/testing/TestUtil.h" @@ -1203,19 +1204,6 @@ TEST(Matrix, warpCTC) { } } -int outputSizeCol2Vol( - int imageSize, int filterSize, int padding, int stride, bool caffeMode) { - int outputSize; - if (!caffeMode) { - outputSize = - (imageSize - filterSize + 2 * padding + stride - 1) / stride + 1; - } else { - outputSize = (imageSize - filterSize + 2 * padding) / stride + 1; - } - CHECK_GE(outputSize, 1); - return outputSize; -} - void testMatrixCol2Vol(int depth, int height, int width) { int channel = 3; int filterX = 3, filterY = 4, filterZ = 5; @@ -1229,9 +1217,9 @@ void testMatrixCol2Vol(int depth, int height, int width) { cpuImage->randomizeUniform(); gpuImage->copyFrom(*cpuImage); - int outD = outputSizeCol2Vol(depth, filterZ, padZ, strideZ, true); - int outH = outputSizeCol2Vol(height, filterY, padZ, strideY, true); - int outW = outputSizeCol2Vol(width, filterX, padZ, strideX, true); + int outD = outputSize(depth, filterZ, padZ, strideZ, true); + int outH = outputSize(height, filterY, padY, strideY, true); + int outW = outputSize(width, filterX, padX, strideX, true); int colBufHeight = channel * filterZ * filterY * filterX; int colBufWidth = outD * outH * outW; @@ -1305,11 +1293,9 @@ void testMatrixCol2Vol(int depth, int height, int width) { } TEST(Matrix, col2Vol) { - for (auto depth : {9, 16, 64, 128}) { - for (auto height : {9, 11, 73, 128, 256}) { - for (auto width : { - 9, 32, 100, 512, - }) { + for (auto depth : {9, 16, 64}) { + for (auto height : {9, 11, 128}) { + for (auto width : {9, 32, 128}) { VLOG(3) << "depth=" << depth << " height=" << height << " width=" << width; testMatrixCol2Vol(depth, height, width); diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto index 043ae502b0221..8c6eb5b7e1717 100644 --- a/proto/ModelConfig.proto +++ b/proto/ModelConfig.proto @@ -82,7 +82,7 @@ message ConvConfig { // if not set, use img_size optional uint32 img_size_y = 14; - + optional uint32 filter_size_z = 15 [ default = 1 ]; optional uint32 padding_z = 16 [ default = 1 ]; optional uint32 stride_z = 17 [ default = 1 ]; @@ -637,4 +637,4 @@ message ModelConfig { // For External Machine, defining how to split a neural network // into multiple parts. optional ExternalConfig external_config = 9; -}; \ No newline at end of file +}; From f715c740bf2bfedb779ba4876f4d6b16e770e61d Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Mon, 21 Aug 2017 23:07:51 +0800 Subject: [PATCH 14/20] Add_config_parser_for_Conv3D_DeConv3D --- proto/ModelConfig.proto | 1 + python/paddle/trainer/config_parser.py | 266 ++++++++++++++- python/paddle/trainer/recurrent_units.py | 0 .../paddle/trainer_config_helpers/layers.py | 316 ++++++++++++------ .../paddle/trainer_config_helpers/networks.py | 4 +- .../configs/conv3d_deconv3d_test_config.py | 98 ++++++ .../tests/layers_test.py | 4 +- 7 files changed, 581 insertions(+), 108 deletions(-) mode change 100755 => 100644 python/paddle/trainer/recurrent_units.py mode change 100755 => 100644 python/paddle/trainer_config_helpers/layers.py mode change 100755 => 100644 python/paddle/trainer_config_helpers/networks.py create mode 100644 python/paddle/trainer_config_helpers/tests/configs/conv3d_deconv3d_test_config.py diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto index 8c6eb5b7e1717..21049ba0a0269 100644 --- a/proto/ModelConfig.proto +++ b/proto/ModelConfig.proto @@ -489,6 +489,7 @@ message LayerConfig { // to indicate rectangle image data optional uint64 height = 50; optional uint64 width = 51; + optional uint64 depth = 57 [ default = 1 ]; // blank label used in ctc loss optional uint32 blank = 52 [ default = 0 ]; diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index b7b696ef0c13e..49b3c430e7050 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -881,6 +881,42 @@ def __init__(self, config_assert(output_x <= 0) +# please refer to the comments in proto/ModelConfig.proto +@config_class +class Conv3D(Cfg): + def __init__(self, + filter_size, + channels, + padding=None, + stride=None, + groups=None, + filter_channels=None, + output_x=None, + img_size=None, + caffe_mode=True, + filter_size_y=None, + padding_y=None, + stride_y=None, + filter_size_z=None, + padding_z=None, + stride_z=None): + self.add_keys(locals()) + if filter_size_y is None: + self.filter_size_y = filter_size + if padding_y is None: + self.padding_y = padding + if stride_y is None: + self.stride_y = stride + if output_x is not None: + config_assert(output_x <= 0) + if filter_size_z is None: + self.filter_size_z = filter_size + if padding_z is None: + self.padding_z = padding + if stride_z is None: + self.stride_z = stride + + @config_class class BilinearInterp(Cfg): def __init__(self, out_size_x=None, out_size_y=None, channels=None): @@ -1167,6 +1203,20 @@ def get_img_size(input_layer_name, channels): return img_size, img_size_y +def get_img3d_size(input_layer_name, channels): + input = g_layer_map[input_layer_name] + img_pixels = input.size / channels + img_size = input.width if input.width > 0 else int(img_pixels**0.5) + img_size_y = input.height if input.height > 0 else int(img_pixels / + img_size) + img_size_z = input.depth if input.depth > 1 else 1 + config_assert( + img_size * img_size_y * img_size_z == img_pixels, + "Input layer %s: Incorrect input image size %d * %d * %d for input image pixels %d" + % (input_layer_name, img_size, img_size_y, img_size_z, img_pixels)) + return img_size, img_size_y, img_size_z + + def parse_bilinear(bilinear, input_layer_name, bilinear_conf): parse_image(bilinear, input_layer_name, bilinear_conf.image_conf) bilinear_conf.out_size_x = bilinear.out_size_x @@ -1277,6 +1327,50 @@ def parse_conv(conv, input_layer_name, conv_conf, num_filters, trans=False): conv_conf.stride_y, conv_conf.caffe_mode) +#caffe_mode: compute the output size using floor instead of ceil, +# which is consistent of caffe and CuDNN's convention. +def parse_conv3d(conv, input_layer_name, conv_conf, num_filters, trans=False): + conv_conf.filter_size = conv.filter_size + conv_conf.filter_size_y = conv.filter_size_y + conv_conf.filter_size_z = conv.filter_size_z + conv_conf.channels = conv.channels + conv_conf.padding = conv.padding + conv_conf.padding_y = conv.padding_y + conv_conf.padding_z = conv.padding_z + conv_conf.stride = conv.stride + conv_conf.stride_y = conv.stride_y + conv_conf.stride_z = conv.stride_z + conv_conf.groups = conv.groups + conv_conf.caffe_mode = conv.caffe_mode + + if not trans: + conv_conf.filter_channels = conv.channels / conv.groups + conv_conf.img_size, conv_conf.img_size_y, conv_conf.img_size_z = \ + get_img3d_size(input_layer_name, conv.channels) + conv_conf.output_x = cnn_output_size( + conv_conf.img_size, conv_conf.filter_size, conv_conf.padding, + conv_conf.stride, conv_conf.caffe_mode) + conv_conf.output_y = cnn_output_size( + conv_conf.img_size_y, conv_conf.filter_size_y, conv_conf.padding_y, + conv_conf.stride_y, conv_conf.caffe_mode) + conv_conf.output_z = cnn_output_size( + conv_conf.img_size_z, conv_conf.filter_size_z, conv_conf.padding_z, + conv_conf.stride_z, conv_conf.caffe_mode) + else: + conv_conf.filter_channels = num_filters / conv.groups + conv_conf.output_x, conv_conf.output_y, conv_conf.output_z = \ + get_img3d_size(input_layer_name, conv.channels) + conv_conf.img_size = cnn_image_size( + conv_conf.output_x, conv_conf.filter_size, conv_conf.padding, + conv_conf.stride, conv_conf.caffe_mode) + conv_conf.img_size_y = cnn_image_size( + conv_conf.output_y, conv_conf.filter_size_y, conv_conf.padding_y, + conv_conf.stride_y, conv_conf.caffe_mode) + conv_conf.img_size_z = cnn_image_size( + conv_conf.output_z, conv_conf.filter_size_z, conv_conf.padding_z, + conv_conf.stride_z, conv_conf.caffe_mode) + + def parse_block_expand(block_expand, input_layer_name, block_expand_conf): block_expand_conf.channels = block_expand.channels block_expand_conf.stride_x = block_expand.stride_x @@ -1580,6 +1674,9 @@ def set_layer_height_width(self, height, width): self.config.height = height self.config.width = width + def set_layer_depth(self, depth): + self.config.depth = depth + def set_cnn_layer(self, input_layer_name, height, @@ -1763,11 +1860,19 @@ def __init__(self, name, inputs, size, input_num, num_classes, @config_layer('data') class DataLayer(LayerBase): - def __init__(self, name, size, height=None, width=None, device=None): + def __init__(self, + name, + size, + height=None, + width=None, + depth=None, + device=None): super(DataLayer, self).__init__( name, 'data', size, inputs=[], device=device) if height and width: self.set_layer_height_width(height, width) + if depth: + self.set_layer_depth(depth) ''' @@ -1882,7 +1987,7 @@ def __init__(self, def calc_parameter_size(self, conv_conf): return self.config.num_filters * conv_conf.filter_channels \ - * (conv_conf.filter_size * conv_conf.filter_size_y) + * (conv_conf.filter_size * conv_conf.filter_size_y) @config_layer('exconv') @@ -1895,6 +2000,163 @@ class ConvLayer(ConvLayerBase): layer_type = 'cudnn_conv' +@config_layer('conv_3d') +class Conv3DLayerBase(LayerBase): + def __init__(self, + name, + inputs=[], + bias=True, + num_filters=None, + shared_biases=False, + **xargs): + super(Conv3DLayerBase, self).__init__( + name, self.layer_type, 0, inputs=inputs, **xargs) + + if num_filters is not None: + self.config.num_filters = num_filters + + use_gpu = int(g_command_config_args.get("use_gpu", 0)) + parallel_nn = int(g_command_config_args.get("parallel_nn", 0)) + + # Automatically select cudnn_type for GPU and exconv for CPU + # if set type=conv, but still reserve the way user specify + # exconv or cudnn_conv manually. + if self.layer_type == "cudnn_conv3d": + config_assert(use_gpu, "cudnn_conv3d only support GPU") + + # need to specify layer in config + self.config.type = self.layer_type + + if shared_biases is not None: + self.config.shared_biases = shared_biases + + for input_index in xrange(len(self.inputs)): + input_layer = self.get_input_layer(input_index) + conv_conf = self.config.inputs[input_index].conv_conf + parse_conv3d( + self.inputs[input_index].conv, input_layer.name, conv_conf, + num_filters + ) # for z-axis pad:0, strid:1, filter_size:1, img_size:1 + psize = self.calc_parameter_size(conv_conf) + self.create_input_parameter(input_index, psize) + self.set_cnn_layer(name, conv_conf.output_z, conv_conf.output_y, + conv_conf.output_x, self.config.num_filters) + + psize = self.config.size + if shared_biases: + psize = self.config.num_filters + self.create_bias_parameter(bias, psize, [psize, 1]) + + def calc_parameter_size(self, conv_conf): + return self.config.num_filters * conv_conf.filter_channels \ + * (conv_conf.filter_size * conv_conf.filter_size_y \ + * conv_conf.filter_size_z) + + def set_layer_height_width(self, depth, height, width): + self.config.depth = depth + self.config.height = height + self.config.width = width + + def set_cnn_layer(self, + input_layer_name, + depth, + height, + width, + channels, + is_print=True): + size = depth * height * width * channels + self.set_layer_size(size) + self.set_layer_height_width(depth, height, width) + if is_print: + print("output for %s: c = %d, d = %d, h = %d, w = %d, size = %d" % + (input_layer_name, channels, depth, height, width, size)) + + +@config_layer('conv3d') +class Conv3DLayer(Conv3DLayerBase): + layer_type = 'conv3d' + + +@config_layer('convt_3d') +class Conv3DTransLayerBase(LayerBase): + def __init__(self, + name, + inputs=[], + bias=True, + num_filters=None, + shared_biases=False, + **xargs): + super(Conv3DTransLayerBase, self).__init__( + name, self.layer_type, 0, inputs=inputs, **xargs) + + if num_filters is not None: + self.config.num_filters = num_filters + + use_gpu = int(g_command_config_args.get("use_gpu", 0)) + parallel_nn = int(g_command_config_args.get("parallel_nn", 0)) + + # Automatically select cudnn_type for GPU and exconv for CPU + # if set type=conv, but still reserve the way user specify + # exconv or cudnn_conv manually. + if self.layer_type == "cudnn_deconv3d": + config_assert(use_gpu, "cudnn_conv3d only support GPU") + + # need to specify layer in config + self.config.type = self.layer_type + + if shared_biases is not None: + self.config.shared_biases = shared_biases + + for input_index in xrange(len(self.inputs)): + input_layer = self.get_input_layer(input_index) + conv_conf = self.config.inputs[input_index].conv_conf + parse_conv3d( + self.inputs[input_index].conv, + input_layer.name, + conv_conf, + num_filters, + trans=True + ) # for z-axis pad:0, strid:1, filter_size:1, img_size:1 + psize = self.calc_parameter_size(conv_conf) + self.create_input_parameter(input_index, psize) + self.set_cnn_layer(name, conv_conf.img_size_z, conv_conf.img_size_y, + conv_conf.img_size, self.config.num_filters) + + psize = self.config.size + if shared_biases: + psize = self.config.num_filters + self.create_bias_parameter(bias, psize, [psize, 1]) + + def calc_parameter_size(self, conv_conf): + return self.config.num_filters * conv_conf.filter_channels \ + * (conv_conf.filter_size * conv_conf.filter_size_y \ + * conv_conf.filter_size_z) + + def set_layer_height_width(self, depth, height, width): + self.config.depth = depth + self.config.height = height + self.config.width = width + + def set_cnn_layer(self, + input_layer_name, + depth, + height, + width, + channels, + is_print=True): + size = depth * height * width * channels + self.set_layer_size(size) + self.set_layer_height_width(depth, height, width) + if is_print: + print("output for %s: c = %d, d = %d, h = %d, w = %d, size = %d" % + (input_layer_name, channels, depth, height, width, size)) + + +@config_layer('deconv3d') +class DeConv3DLayer(Conv3DTransLayerBase): + layer_type = 'deconv3d' + + @config_layer('convt') class ConvTransLayerBase(LayerBase): layer_type = 'convt' diff --git a/python/paddle/trainer/recurrent_units.py b/python/paddle/trainer/recurrent_units.py old mode 100755 new mode 100644 diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py old mode 100755 new mode 100644 index 1bc55c8696015..6953f134c5dac --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -31,108 +31,34 @@ import copy __all__ = [ - 'full_matrix_projection', - 'AggregateLevel', - 'ExpandLevel', - 'identity_projection', - 'dotmul_projection', - 'dotmul_operator', - 'repeat_layer', - 'seq_reshape_layer', - 'table_projection', - 'mixed_layer', - 'data_layer', - 'embedding_layer', - 'fc_layer', - 'grumemory', - 'pooling_layer', - 'lstmemory', - 'last_seq', - 'first_seq', - 'cos_sim', - 'hsigmoid', - 'conv_projection', - 'mse_cost', - 'regression_cost', - 'classification_cost', - 'LayerOutput', - 'img_conv_layer', - 'img_pool_layer', - 'batch_norm_layer', - 'img_cmrnorm_layer', - 'addto_layer', - 'concat_layer', - 'seq_concat_layer', - 'lstm_step_layer', - 'recurrent_group', - 'memory', - 'StaticInput', - 'expand_layer', - 'scaling_layer', - 'scaling_projection', - 'power_layer', - 'interpolation_layer', - 'bilinear_interp_layer', - 'trans_layer', - 'rotate_layer', - 'sum_to_one_norm_layer', - 'row_l2_norm_layer', - 'get_output_layer', - 'LayerType', - 'context_projection', - 'beam_search', - 'maxid_layer', - 'GeneratedInput', - 'SubsequenceInput', - 'gru_step_layer', - 'gru_step_naive_layer', - 'recurrent_layer', - 'BaseGeneratedInput', - 'conv_operator', - 'conv_shift_layer', - 'tensor_layer', - 'selective_fc_layer', - 'sampling_id_layer', - 'slope_intercept_layer', - 'trans_full_matrix_projection', - 'linear_comb_layer', - 'convex_comb_layer', - 'ctc_layer', - 'warp_ctc_layer', - 'crf_layer', - 'crf_decoding_layer', - 'nce_layer', - 'cross_entropy_with_selfnorm', - 'cross_entropy', - 'multi_binary_label_cross_entropy', - 'sum_cost', - 'rank_cost', - 'lambda_cost', - 'huber_cost', - 'block_expand_layer', - 'maxout_layer', - 'out_prod_layer', - 'printer_layer', - 'print_layer', - 'priorbox_layer', - 'cross_channel_norm_layer', - 'multibox_loss_layer', - 'detection_output_layer', - 'spp_layer', - 'pad_layer', - 'eos_layer', - 'smooth_l1_cost', - 'layer_support', - 'multiplex_layer', - 'row_conv_layer', - 'dropout_layer', - 'prelu_layer', - 'gated_unit_layer', - 'crop_layer', - 'sub_nested_seq_layer', - 'clip_layer', - 'slice_projection', - 'kmax_sequence_score_layer', + 'full_matrix_projection', 'AggregateLevel', 'ExpandLevel', + 'identity_projection', 'dotmul_projection', 'dotmul_operator', + 'repeat_layer', 'seq_reshape_layer', 'table_projection', 'mixed_layer', + 'data_layer', 'embedding_layer', 'fc_layer', 'grumemory', 'pooling_layer', + 'lstmemory', 'last_seq', 'first_seq', 'cos_sim', 'hsigmoid', + 'conv_projection', 'mse_cost', 'regression_cost', 'classification_cost', + 'LayerOutput', 'img_conv_layer', 'img_pool_layer', 'batch_norm_layer', + 'img_cmrnorm_layer', 'addto_layer', 'concat_layer', 'seq_concat_layer', + 'lstm_step_layer', 'recurrent_group', 'memory', 'StaticInput', + 'expand_layer', 'scaling_layer', 'scaling_projection', 'power_layer', + 'interpolation_layer', 'bilinear_interp_layer', 'trans_layer', + 'rotate_layer', 'sum_to_one_norm_layer', 'row_l2_norm_layer', + 'get_output_layer', 'LayerType', 'context_projection', 'beam_search', + 'maxid_layer', 'GeneratedInput', 'SubsequenceInput', 'gru_step_layer', + 'gru_step_naive_layer', 'recurrent_layer', 'BaseGeneratedInput', + 'conv_operator', 'conv_shift_layer', 'tensor_layer', 'selective_fc_layer', + 'sampling_id_layer', 'slope_intercept_layer', + 'trans_full_matrix_projection', 'linear_comb_layer', 'convex_comb_layer', + 'ctc_layer', 'warp_ctc_layer', 'crf_layer', 'crf_decoding_layer', + 'nce_layer', 'cross_entropy_with_selfnorm', 'cross_entropy', + 'multi_binary_label_cross_entropy', 'sum_cost', 'rank_cost', 'lambda_cost', + 'huber_cost', 'block_expand_layer', 'maxout_layer', 'out_prod_layer', + 'printer_layer', 'print_layer', 'priorbox_layer', + 'cross_channel_norm_layer', 'multibox_loss_layer', 'detection_output_layer', + 'spp_layer', 'pad_layer', 'eos_layer', 'smooth_l1_cost', 'layer_support', + 'multiplex_layer', 'row_conv_layer', 'dropout_layer', 'prelu_layer', + 'gated_unit_layer', 'crop_layer', 'sub_nested_seq_layer', 'clip_layer', + 'slice_projection', 'kmax_sequence_score_layer', 'img_conv3d_layer' ] @@ -214,6 +140,9 @@ class LayerType(object): CRF_DECODING_LAYER = 'crf_decoding' NCE_LAYER = 'nce' + CONV3D_LAYER = 'conv3d' + DECONV3D_LAYER = 'deconv3d' + RANK_COST = 'rank-cost' LAMBDA_COST = 'lambda_cost' HUBER = 'huber' @@ -878,7 +807,8 @@ def mixed_layer(size=0, @layer_support() -def data_layer(name, size, height=None, width=None, layer_attr=None): +def data_layer(name, size, height=None, width=None, depth=None, + layer_attr=None): """ Define DataLayer For NeuralNetwork. @@ -907,6 +837,7 @@ def data_layer(name, size, height=None, width=None, layer_attr=None): size=size, height=height, width=width, + depth=depth, **ExtraLayerAttribute.to_kwargs(layer_attr)) return LayerOutput(name, LayerType.DATA, size=size) @@ -6210,3 +6141,182 @@ def kmax_sequence_score_layer(input, name=None, beam_size=1): return LayerOutput( name, LayerType.KMAX_SEQ_SCORE, parents=[input], size=input.size) + + +@wrap_name_default("conv3d") +@wrap_param_attr_default() +@wrap_bias_attr_default() +@wrap_act_default(act=ReluActivation()) +@layer_support(DROPOUT) +def img_conv3d_layer(input, + filter_size, + num_filters, + name=None, + num_channels=None, + act=None, + groups=1, + stride=1, + padding=0, + bias_attr=None, + param_attr=None, + shared_biases=True, + layer_attr=None, + filter_size_y=None, + stride_y=None, + padding_y=None, + filter_size_z=None, + stride_z=None, + padding_z=None, + trans=False, + layer_type=None): + """ + + The example usage is: + + .. code-block:: python + + conv = img_conv3d_layer(input=data, filter_size=1, filter_size_y=1, + num_channels=8, + num_filters=16, stride=1, + bias_attr=False, + act=ReluActivation()) + + :param name: Layer name. + :type name: basestring + :param input: Layer Input. + :type input: LayerOutput + :param filter_size: The x dimension of a filter kernel. Or input a tuple for + two image dimension. + :type filter_size: int|tuple|list + :param filter_size_y: The y dimension of a filter kernel. Since PaddlePaddle + currently supports rectangular filters, the filter's + shape will be (filter_size, filter_size_y). + :type filter_size_y: int|None + :param num_filters: Each filter group's number of filter + :param act: Activation type. Default is tanh + :type act: BaseActivation + :param groups: Group size of filters. + :type groups: int + :param stride: The x dimension of the stride. Or input a tuple for two image + dimension. + :type stride: int|tuple|list + :param stride_y: The y dimension of the stride. + :type stride_y: int + :param padding: The x dimension of the padding. Or input a tuple for two + image dimension + :type padding: int|tuple|list + :param padding_y: The y dimension of the padding. + :type padding_y: int + :param bias_attr: Convolution bias attribute. None means default bias. + False means no bias. + :type bias_attr: ParameterAttribute|False + :param num_channels: number of input channels. If None will be set + automatically from previous output. + :type num_channels: int + :param param_attr: Convolution param attribute. None means default attribute + :type param_attr: ParameterAttribute + :param shared_biases: Is biases will be shared between filters or not. + :type shared_biases: bool + :param layer_attr: Layer Extra Attribute. + :type layer_attr: ExtraLayerAttribute + :param trans: true if it is a convTransLayer, false if it is a convLayer + :type trans: bool + :param layer_type: specify the layer_type, default is None. If trans=True, + layer_type has to be "exconvt" or "cudnn_convt", + otherwise layer_type has to be either "exconv" or + "cudnn_conv" + :type layer_type: String + :return: LayerOutput object. + :rtype: LayerOutput + """ + if num_channels is None: + assert input.num_filters is not None + num_channels = input.num_filters + + if filter_size_y is None: + if isinstance(filter_size, collections.Sequence): + assert len(filter_size) == 2 + filter_size, filter_size_y = filter_size + else: + filter_size_y = filter_size + + if filter_size_z is None: + if isinstance(filter_size, collections.Sequence): + assert len(filter_size) == 2 + filter_size, filter_size_z = filter_size + else: + filter_size_z = filter_size + + if stride_y is None: + if isinstance(stride, collections.Sequence): + assert len(stride) == 2 + stride, stride_y = stride + else: + stride_y = stride + + if stride_z is None: + if isinstance(stride, collections.Sequence): + assert len(stride) == 2 + stride, stride_z = stride + else: + stride_z = stride + + if padding_y is None: + if isinstance(padding, collections.Sequence): + assert len(padding) == 2 + padding, padding_y = padding + else: + padding_y = padding + + if padding_z is None: + if isinstance(padding, collections.Sequence): + assert len(padding) == 2 + padding, padding_z = padding + else: + padding_z = padding + + if param_attr.attr.get('initial_smart'): + # special initial for conv layers. + init_w = (2.0 / (filter_size**2 * num_channels))**0.5 + param_attr.attr["initial_mean"] = 0.0 + param_attr.attr["initial_std"] = init_w + param_attr.attr["initial_strategy"] = 0 + param_attr.attr["initial_smart"] = False + + if layer_type: + if trans: + assert layer_type in ["deconv3d"] + lt = layer_type + else: + lt = LayerType.DECONV3D_LAYER if trans else LayerType.CONV3D_LAYER + + l = Layer( + name=name, + inputs=Input( + input.name, + conv=Conv3D( + filter_size=filter_size, + padding=padding, + stride=stride, + channels=num_channels, + groups=groups, + filter_size_y=filter_size_y, + padding_y=padding_y, + stride_y=stride_y, + filter_size_z=filter_size_z, + padding_z=padding_z, + stride_z=stride_z), + **param_attr.attr), + active_type=act.name, + num_filters=num_filters, + bias=ParamAttr.to_bias(bias_attr), + shared_biases=shared_biases, + type=lt, + **ExtraLayerAttribute.to_kwargs(layer_attr)) + return LayerOutput( + name, + lt, + parents=[input], + activation=act, + num_filters=num_filters, + size=l.config.size) diff --git a/python/paddle/trainer_config_helpers/networks.py b/python/paddle/trainer_config_helpers/networks.py old mode 100755 new mode 100644 index 34be203ee2545..28a71cf788f2b --- a/python/paddle/trainer_config_helpers/networks.py +++ b/python/paddle/trainer_config_helpers/networks.py @@ -1406,7 +1406,7 @@ def inputs(layers, *args): if len(args) != 0: layers.extend(args) - Inputs(*[l.name for l in layers]) + Inputs(* [l.name for l in layers]) def outputs(layers, *args): @@ -1456,7 +1456,7 @@ def __dfs_travel__(layer, assert len(layers) > 0 if HasInputsSet(): # input already set - Outputs(*[l.name for l in layers]) + Outputs(* [l.name for l in layers]) return # just return outputs. if len(layers) != 1: diff --git a/python/paddle/trainer_config_helpers/tests/configs/conv3d_deconv3d_test_config.py b/python/paddle/trainer_config_helpers/tests/configs/conv3d_deconv3d_test_config.py new file mode 100644 index 0000000000000..da0d23d057d80 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/conv3d_deconv3d_test_config.py @@ -0,0 +1,98 @@ +from paddle.trainer_config_helpers import * + +settings(batch_size=1000, learning_rate=1e-5) + +num_channels = 3 +filter_size = 3 +filter_size_y = 3 +filter_size_z = 3 +stride = 2 +stride_y = 2 +stride_z = 2 +padding = 1 +padding_y = 1 +padding_z = 1 +groups = 1 + +data = data_layer( + name='data1', size=12096 * num_channels, height=48, width=42, depth=6) + +conv3d = img_conv3d_layer( + input=data, + name='conv3d_1', + num_filters=16, + num_channels=num_channels, + filter_size=filter_size, + filter_size_y=filter_size, + filter_size_z=filter_size, + stride=stride, + stride_y=stride_y, + stride_z=stride_z, + padding=padding, + padding_y=padding_y, + padding_z=padding_z, + groups=groups, + bias_attr=True, + shared_biases=True, + trans=False, + layer_type="conv3d", + act=LinearActivation()) + +deconv3d = img_conv3d_layer( + input=data, + name='deconv3d_1', + num_filters=16, + num_channels=num_channels, + filter_size=filter_size, + filter_size_y=filter_size, + filter_size_z=filter_size, + stride=stride, + stride_y=stride_y, + stride_z=stride_z, + padding=padding, + padding_y=padding_y, + padding_z=padding_z, + groups=groups, + bias_attr=True, + shared_biases=True, + trans=True, + layer_type="deconv3d", + act=LinearActivation()) + +data = data_layer(name="input", size=8 * 16 * 16) +conv1 = img_conv_layer( + input=data, + filter_size=1, + filter_size_y=1, + num_channels=8, + num_filters=16, + stride=1, + bias_attr=False, + act=ReluActivation(), + layer_type="exconv") +conv2 = img_conv_layer( + input=data, + filter_size=1, + filter_size_y=1, + num_channels=8, + num_filters=16, + stride=1, + bias_attr=False, + act=ReluActivation(), + layer_type="exconv") + +concat = concat_layer(input=[conv1, conv2]) + +conv = img_conv_layer( + input=data, + filter_size=1, + filter_size_y=1, + num_channels=8, + num_filters=16, + stride=1, + bias_attr=True, + act=LinearActivation(), + groups=2, + layer_type="exconv") + +outputs(concat, conv) diff --git a/python/paddle/trainer_config_helpers/tests/layers_test.py b/python/paddle/trainer_config_helpers/tests/layers_test.py index 05902ea293df5..44d1c1c9b2833 100644 --- a/python/paddle/trainer_config_helpers/tests/layers_test.py +++ b/python/paddle/trainer_config_helpers/tests/layers_test.py @@ -16,4 +16,6 @@ if __name__ == '__main__': parse_config_and_serialize( - 'trainer_config_helpers/tests/layers_test_config.py', '') + 'trainer_config_helpers/tests/configs/conv3d_deconv3d_test_config.py', + '') +# layers_test_config.py From 2710584ff1d5d299361c1b4492d3368ccbdb0378 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Wed, 23 Aug 2017 22:05:50 +0800 Subject: [PATCH 15/20] fix above comments --- python/paddle/trainer/config_parser.py | 212 ++++++------------ .../paddle/trainer_config_helpers/layers.py | 76 ++----- .../configs/conv3d_deconv3d_test_config.py | 97 ++++---- 3 files changed, 130 insertions(+), 255 deletions(-) diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 49b3c430e7050..c0843a7357837 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -901,20 +901,14 @@ def __init__(self, padding_z=None, stride_z=None): self.add_keys(locals()) - if filter_size_y is None: - self.filter_size_y = filter_size - if padding_y is None: - self.padding_y = padding - if stride_y is None: - self.stride_y = stride + self.filter_size_y = filter_size_y if filter_size_y else filter_size + self.filter_size_z = filter_size_z if filter_size_z else filter_size + self.padding_y = padding_y if padding_y else padding + self.padding_z = padding_z if padding_z else padding + self.stride_y = stride_y if stride_y else stride + self.stride_z = stride_z if stride_z else stride if output_x is not None: config_assert(output_x <= 0) - if filter_size_z is None: - self.filter_size_z = filter_size - if padding_z is None: - self.padding_z = padding - if stride_z is None: - self.stride_z = stride @config_class @@ -1206,10 +1200,10 @@ def get_img_size(input_layer_name, channels): def get_img3d_size(input_layer_name, channels): input = g_layer_map[input_layer_name] img_pixels = input.size / channels - img_size = input.width if input.width > 0 else int(img_pixels**0.5) - img_size_y = input.height if input.height > 0 else int(img_pixels / - img_size) - img_size_z = input.depth if input.depth > 1 else 1 + img_size = input.width + img_size_y = input.height + img_size_z = input.depth + config_assert( img_size * img_size_y * img_size_z == img_pixels, "Input layer %s: Incorrect input image size %d * %d * %d for input image pixels %d" @@ -2000,8 +1994,10 @@ class ConvLayer(ConvLayerBase): layer_type = 'cudnn_conv' -@config_layer('conv_3d') -class Conv3DLayerBase(LayerBase): +@config_layer('convt') +class ConvTransLayerBase(LayerBase): + layer_type = 'convt' + def __init__(self, name, inputs=[], @@ -2009,7 +2005,7 @@ def __init__(self, num_filters=None, shared_biases=False, **xargs): - super(Conv3DLayerBase, self).__init__( + super(ConvTransLayerBase, self).__init__( name, self.layer_type, 0, inputs=inputs, **xargs) if num_filters is not None: @@ -2018,12 +2014,17 @@ def __init__(self, use_gpu = int(g_command_config_args.get("use_gpu", 0)) parallel_nn = int(g_command_config_args.get("parallel_nn", 0)) - # Automatically select cudnn_type for GPU and exconv for CPU - # if set type=conv, but still reserve the way user specify - # exconv or cudnn_conv manually. - if self.layer_type == "cudnn_conv3d": - config_assert(use_gpu, "cudnn_conv3d only support GPU") + # Automatically select cudnn_type for GPU and exconvt for CPU + # if set type=exconvt, but still reserve the way user specify + # exconvt or cudnn_convt manually. + if self.layer_type == "cudnn_convt": + config_assert(use_gpu, "cudnn_convt only support GPU") + if (use_gpu == 1 and self.layer_type != "exconvt" and + (parallel_nn == 0 or self.config.device > -1)): + self.layer_type = "cudnn_convt" + else: + self.layer_type = "exconvt" # need to specify layer in config self.config.type = self.layer_type @@ -2032,15 +2033,17 @@ def __init__(self, for input_index in xrange(len(self.inputs)): input_layer = self.get_input_layer(input_index) + parse_conv( + self.inputs[input_index].conv, + input_layer.name, + self.config.inputs[input_index].conv_conf, + num_filters, + trans=True) conv_conf = self.config.inputs[input_index].conv_conf - parse_conv3d( - self.inputs[input_index].conv, input_layer.name, conv_conf, - num_filters - ) # for z-axis pad:0, strid:1, filter_size:1, img_size:1 psize = self.calc_parameter_size(conv_conf) self.create_input_parameter(input_index, psize) - self.set_cnn_layer(name, conv_conf.output_z, conv_conf.output_y, - conv_conf.output_x, self.config.num_filters) + self.set_cnn_layer(name, conv_conf.img_size_y, conv_conf.img_size, + self.config.num_filters) psize = self.config.size if shared_biases: @@ -2048,62 +2051,42 @@ def __init__(self, self.create_bias_parameter(bias, psize, [psize, 1]) def calc_parameter_size(self, conv_conf): - return self.config.num_filters * conv_conf.filter_channels \ - * (conv_conf.filter_size * conv_conf.filter_size_y \ - * conv_conf.filter_size_z) + return conv_conf.channels * conv_conf.filter_channels \ + * (conv_conf.filter_size * conv_conf.filter_size_y) - def set_layer_height_width(self, depth, height, width): - self.config.depth = depth - self.config.height = height - self.config.width = width - def set_cnn_layer(self, - input_layer_name, - depth, - height, - width, - channels, - is_print=True): - size = depth * height * width * channels - self.set_layer_size(size) - self.set_layer_height_width(depth, height, width) - if is_print: - print("output for %s: c = %d, d = %d, h = %d, w = %d, size = %d" % - (input_layer_name, channels, depth, height, width, size)) +@config_layer('exconvt') +class ConvTransLayer(ConvTransLayerBase): + layer_type = 'exconvt' -@config_layer('conv3d') -class Conv3DLayer(Conv3DLayerBase): - layer_type = 'conv3d' +@config_layer('cudnn_convt') +class ConvTransLayer(ConvTransLayerBase): + layer_type = 'cudnn_convt' -@config_layer('convt_3d') -class Conv3DTransLayerBase(LayerBase): +@config_layer('conv_3d') +class Conv3DLayerBase(LayerBase): def __init__(self, name, inputs=[], bias=True, num_filters=None, - shared_biases=False, + shared_biases=True, **xargs): - super(Conv3DTransLayerBase, self).__init__( + super(Conv3DLayerBase, self).__init__( name, self.layer_type, 0, inputs=inputs, **xargs) if num_filters is not None: self.config.num_filters = num_filters - use_gpu = int(g_command_config_args.get("use_gpu", 0)) - parallel_nn = int(g_command_config_args.get("parallel_nn", 0)) - - # Automatically select cudnn_type for GPU and exconv for CPU - # if set type=conv, but still reserve the way user specify - # exconv or cudnn_conv manually. - if self.layer_type == "cudnn_deconv3d": - config_assert(use_gpu, "cudnn_conv3d only support GPU") - # need to specify layer in config self.config.type = self.layer_type + trans = False + if self.config.type == "deconv3d": + trans = True + if shared_biases is not None: self.config.shared_biases = shared_biases @@ -2115,12 +2098,17 @@ def __init__(self, input_layer.name, conv_conf, num_filters, - trans=True + trans=trans ) # for z-axis pad:0, strid:1, filter_size:1, img_size:1 psize = self.calc_parameter_size(conv_conf) self.create_input_parameter(input_index, psize) - self.set_cnn_layer(name, conv_conf.img_size_z, conv_conf.img_size_y, - conv_conf.img_size, self.config.num_filters) + if trans: + self.set_cnn_layer(name, conv_conf.img_size_z, + conv_conf.img_size_y, conv_conf.img_size, + self.config.num_filters) + else: + self.set_cnn_layer(name, conv_conf.output_z, conv_conf.output_y, + conv_conf.output_x, self.config.num_filters) psize = self.config.size if shared_biases: @@ -2132,11 +2120,6 @@ def calc_parameter_size(self, conv_conf): * (conv_conf.filter_size * conv_conf.filter_size_y \ * conv_conf.filter_size_z) - def set_layer_height_width(self, depth, height, width): - self.config.depth = depth - self.config.height = height - self.config.width = width - def set_cnn_layer(self, input_layer_name, depth, @@ -2146,86 +2129,21 @@ def set_cnn_layer(self, is_print=True): size = depth * height * width * channels self.set_layer_size(size) - self.set_layer_height_width(depth, height, width) + self.set_layer_height_width(height, width) + self.set_layer_depth(depth) if is_print: print("output for %s: c = %d, d = %d, h = %d, w = %d, size = %d" % (input_layer_name, channels, depth, height, width, size)) -@config_layer('deconv3d') -class DeConv3DLayer(Conv3DTransLayerBase): - layer_type = 'deconv3d' - - -@config_layer('convt') -class ConvTransLayerBase(LayerBase): - layer_type = 'convt' - - def __init__(self, - name, - inputs=[], - bias=True, - num_filters=None, - shared_biases=False, - **xargs): - super(ConvTransLayerBase, self).__init__( - name, self.layer_type, 0, inputs=inputs, **xargs) - - if num_filters is not None: - self.config.num_filters = num_filters - - use_gpu = int(g_command_config_args.get("use_gpu", 0)) - parallel_nn = int(g_command_config_args.get("parallel_nn", 0)) - - # Automatically select cudnn_type for GPU and exconvt for CPU - # if set type=exconvt, but still reserve the way user specify - # exconvt or cudnn_convt manually. - if self.layer_type == "cudnn_convt": - config_assert(use_gpu, "cudnn_convt only support GPU") - - if (use_gpu == 1 and self.layer_type != "exconvt" and - (parallel_nn == 0 or self.config.device > -1)): - self.layer_type = "cudnn_convt" - else: - self.layer_type = "exconvt" - # need to specify layer in config - self.config.type = self.layer_type - - if shared_biases is not None: - self.config.shared_biases = shared_biases - - for input_index in xrange(len(self.inputs)): - input_layer = self.get_input_layer(input_index) - parse_conv( - self.inputs[input_index].conv, - input_layer.name, - self.config.inputs[input_index].conv_conf, - num_filters, - trans=True) - conv_conf = self.config.inputs[input_index].conv_conf - psize = self.calc_parameter_size(conv_conf) - self.create_input_parameter(input_index, psize) - self.set_cnn_layer(name, conv_conf.img_size_y, conv_conf.img_size, - self.config.num_filters) - - psize = self.config.size - if shared_biases: - psize = self.config.num_filters - self.create_bias_parameter(bias, psize, [psize, 1]) - - def calc_parameter_size(self, conv_conf): - return conv_conf.channels * conv_conf.filter_channels \ - * (conv_conf.filter_size * conv_conf.filter_size_y) - - -@config_layer('exconvt') -class ConvTransLayer(ConvTransLayerBase): - layer_type = 'exconvt' +@config_layer('conv3d') +class Conv3DLayer(Conv3DLayerBase): + layer_type = 'conv3d' -@config_layer('cudnn_convt') -class ConvTransLayer(ConvTransLayerBase): - layer_type = 'cudnn_convt' +@config_layer('deconv3d') +class Conv3DLayer(Conv3DLayerBase): + layer_type = 'deconv3d' @config_layer('norm') diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 6953f134c5dac..e3ae81459f359 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -6161,12 +6161,6 @@ def img_conv3d_layer(input, param_attr=None, shared_biases=True, layer_attr=None, - filter_size_y=None, - stride_y=None, - padding_y=None, - filter_size_z=None, - stride_z=None, - padding_z=None, trans=False, layer_type=None): """ @@ -6175,7 +6169,7 @@ def img_conv3d_layer(input, .. code-block:: python - conv = img_conv3d_layer(input=data, filter_size=1, filter_size_y=1, + conv = img_conv3d_layer(input=data, filter_size=1, num_channels=8, num_filters=16, stride=1, bias_attr=False, @@ -6185,13 +6179,8 @@ def img_conv3d_layer(input, :type name: basestring :param input: Layer Input. :type input: LayerOutput - :param filter_size: The x dimension of a filter kernel. Or input a tuple for - two image dimension. + :param filter_size: The x dimension of a filter kernel. Or input a list. :type filter_size: int|tuple|list - :param filter_size_y: The y dimension of a filter kernel. Since PaddlePaddle - currently supports rectangular filters, the filter's - shape will be (filter_size, filter_size_y). - :type filter_size_y: int|None :param num_filters: Each filter group's number of filter :param act: Activation type. Default is tanh :type act: BaseActivation @@ -6200,13 +6189,9 @@ def img_conv3d_layer(input, :param stride: The x dimension of the stride. Or input a tuple for two image dimension. :type stride: int|tuple|list - :param stride_y: The y dimension of the stride. - :type stride_y: int :param padding: The x dimension of the padding. Or input a tuple for two image dimension :type padding: int|tuple|list - :param padding_y: The y dimension of the padding. - :type padding_y: int :param bias_attr: Convolution bias attribute. None means default bias. False means no bias. :type bias_attr: ParameterAttribute|False @@ -6233,47 +6218,26 @@ def img_conv3d_layer(input, assert input.num_filters is not None num_channels = input.num_filters - if filter_size_y is None: - if isinstance(filter_size, collections.Sequence): - assert len(filter_size) == 2 - filter_size, filter_size_y = filter_size - else: - filter_size_y = filter_size - - if filter_size_z is None: - if isinstance(filter_size, collections.Sequence): - assert len(filter_size) == 2 - filter_size, filter_size_z = filter_size - else: - filter_size_z = filter_size - - if stride_y is None: - if isinstance(stride, collections.Sequence): - assert len(stride) == 2 - stride, stride_y = stride - else: - stride_y = stride - - if stride_z is None: - if isinstance(stride, collections.Sequence): - assert len(stride) == 2 - stride, stride_z = stride - else: - stride_z = stride + if isinstance(filter_size, collections.Sequence): + assert len(filter_size) == 3 + filter_size, filter_size_y, filter_size_z = filter_size + else: + filter_size_y = filter_size + filter_size_z = filter_size - if padding_y is None: - if isinstance(padding, collections.Sequence): - assert len(padding) == 2 - padding, padding_y = padding - else: - padding_y = padding + if isinstance(stride, collections.Sequence): + assert len(stride) == 3 + stride, stride_y, stride_z = stride + else: + stride_y = stride + stride_z = stride - if padding_z is None: - if isinstance(padding, collections.Sequence): - assert len(padding) == 2 - padding, padding_z = padding - else: - padding_z = padding + if isinstance(padding, collections.Sequence): + assert len(padding) == 3 + padding, padding_y, padding_z = padding + else: + padding_y = padding + padding_z = padding if param_attr.attr.get('initial_smart'): # special initial for conv layers. diff --git a/python/paddle/trainer_config_helpers/tests/configs/conv3d_deconv3d_test_config.py b/python/paddle/trainer_config_helpers/tests/configs/conv3d_deconv3d_test_config.py index da0d23d057d80..15f7c1d271fce 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/conv3d_deconv3d_test_config.py +++ b/python/paddle/trainer_config_helpers/tests/configs/conv3d_deconv3d_test_config.py @@ -14,23 +14,44 @@ padding_z = 1 groups = 1 -data = data_layer( - name='data1', size=12096 * num_channels, height=48, width=42, depth=6) +data1 = data_layer(name='data1', size=2016 * num_channels, height=48, width=42) -conv3d = img_conv3d_layer( +img_conv_layer( + input=data1, + filter_size=filter_size, + num_channels=num_channels, + num_filters=16, + stride=stride, + padding=padding, + act=LinearActivation(), + bias_attr=False) + +data = data_layer( + name='data', size=12096 * num_channels, height=48, width=42, depth=6) +# first +conv3d_1 = img_conv3d_layer( input=data, name='conv3d_1', num_filters=16, num_channels=num_channels, filter_size=filter_size, - filter_size_y=filter_size, - filter_size_z=filter_size, stride=stride, - stride_y=stride_y, - stride_z=stride_z, padding=padding, - padding_y=padding_y, - padding_z=padding_z, + groups=groups, + bias_attr=True, + shared_biases=True, + trans=False, + layer_type="conv3d", + act=LinearActivation()) +# second +conv3d_2 = img_conv3d_layer( + input=data, + name='conv3d_2', + num_filters=16, + num_channels=num_channels, + filter_size=[filter_size, filter_size_y, filter_size_z], + stride=[stride, stride_y, stride_z], + padding=[padding, padding_y, padding_z], groups=groups, bias_attr=True, shared_biases=True, @@ -38,61 +59,33 @@ layer_type="conv3d", act=LinearActivation()) -deconv3d = img_conv3d_layer( +# first +deconv3d_1 = img_conv3d_layer( input=data, name='deconv3d_1', num_filters=16, num_channels=num_channels, filter_size=filter_size, - filter_size_y=filter_size, - filter_size_z=filter_size, stride=stride, - stride_y=stride_y, - stride_z=stride_z, padding=padding, - padding_y=padding_y, - padding_z=padding_z, groups=groups, bias_attr=True, shared_biases=True, - trans=True, + trans=False, layer_type="deconv3d", act=LinearActivation()) - -data = data_layer(name="input", size=8 * 16 * 16) -conv1 = img_conv_layer( - input=data, - filter_size=1, - filter_size_y=1, - num_channels=8, - num_filters=16, - stride=1, - bias_attr=False, - act=ReluActivation(), - layer_type="exconv") -conv2 = img_conv_layer( - input=data, - filter_size=1, - filter_size_y=1, - num_channels=8, - num_filters=16, - stride=1, - bias_attr=False, - act=ReluActivation(), - layer_type="exconv") - -concat = concat_layer(input=[conv1, conv2]) - -conv = img_conv_layer( +# second +deconv3d_2 = img_conv3d_layer( input=data, - filter_size=1, - filter_size_y=1, - num_channels=8, + name='deconv3d_2', num_filters=16, - stride=1, + num_channels=num_channels, + filter_size=[filter_size, filter_size_y, filter_size_z], + stride=[stride, stride_y, stride_z], + padding=[padding, padding_y, padding_z], + groups=groups, bias_attr=True, - act=LinearActivation(), - groups=2, - layer_type="exconv") - -outputs(concat, conv) + shared_biases=True, + trans=False, + layer_type="deconv3d", + act=LinearActivation()) From 6053f7e36b19a06da14c970a1e4f25a02d1dbcaf Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Mon, 28 Aug 2017 18:10:44 +0800 Subject: [PATCH 16/20] fix previous comments(c++) --- paddle/cuda/include/hl_matrix.h | 2 +- paddle/gserver/layers/Conv3DLayer.cpp | 6 ------ paddle/gserver/layers/DeConv3DLayer.cpp | 6 ------ 3 files changed, 1 insertion(+), 13 deletions(-) diff --git a/paddle/cuda/include/hl_matrix.h b/paddle/cuda/include/hl_matrix.h index a37921b7493e3..c7f2510997219 100644 --- a/paddle/cuda/include/hl_matrix.h +++ b/paddle/cuda/include/hl_matrix.h @@ -241,7 +241,7 @@ extern void hl_matrix_rotate( * @param[in] paddingD padding in the depth. * @param[in] paddingH padding in the height. * @param[in] paddingW padding in the width. - * @param[out] matDst output matrix. + * @param[out] dataDst output matrix. * */ extern void hl_matrix_vol2Col(const real* dataSrc, diff --git a/paddle/gserver/layers/Conv3DLayer.cpp b/paddle/gserver/layers/Conv3DLayer.cpp index db907bbab1c28..7cc9937cce37c 100644 --- a/paddle/gserver/layers/Conv3DLayer.cpp +++ b/paddle/gserver/layers/Conv3DLayer.cpp @@ -53,18 +53,12 @@ bool Conv3DLayer::init(const LayerMap &layerMap, size_t Conv3DLayer::getSize() { CHECK_NE(inputLayers_.size(), 0UL); - // imgSizeH_.clear(); - // imgSizeW_.clear(); - // imgSizeD_.clear(); outputH_.clear(); outputW_.clear(); outputD_.clear(); N_.clear(); size_t layerSize = 0; for (size_t i = 0; i < inputLayers_.size(); ++i) { - // imgSizeH_.push_back(inputLayers_[i]->getOutput().getFrameHeight()); - // imgSizeW_.push_back(inputLayers_[i]->getOutput().getFrameWidth()); - // imgSizeD_.push_back(inputLayers_[i]->getOutput().getFrameDepth()); outputW_.push_back(outputSize( imgSizeW_[i], filterSize_[i], padding_[i], stride_[i], true)); outputH_.push_back(outputSize( diff --git a/paddle/gserver/layers/DeConv3DLayer.cpp b/paddle/gserver/layers/DeConv3DLayer.cpp index b18c06e36c897..7d5c772c89d26 100644 --- a/paddle/gserver/layers/DeConv3DLayer.cpp +++ b/paddle/gserver/layers/DeConv3DLayer.cpp @@ -53,9 +53,6 @@ bool DeConv3DLayer::init(const LayerMap &layerMap, size_t DeConv3DLayer::getSize() { CHECK_NE(inputLayers_.size(), 0UL); - // imgSizeH_.clear(); - // imgSizeW_.clear(); - // imgSizeD_.clear(); outputH_.clear(); outputW_.clear(); outputD_.clear(); @@ -63,9 +60,6 @@ size_t DeConv3DLayer::getSize() { NOut_.clear(); size_t layerSize = 0; for (size_t i = 0; i < inputLayers_.size(); ++i) { - // imgSizeH_.push_back(inputLayers_[i]->getOutput().getFrameHeight()); - // imgSizeW_.push_back(inputLayers_[i]->getOutput().getFrameWidth()); - // imgSizeD_.push_back(inputLayers_[i]->getOutput().getFrameDepth()); outputW_.push_back( imageSize(imgSizeW_[i], filterSize_[i], padding_[i], stride_[i], true)); outputH_.push_back(imageSize( From 34f4f763f9cf52d6c6326613ed839d00ac7c6eb0 Mon Sep 17 00:00:00 2001 From: chengduo Date: Wed, 30 Aug 2017 10:19:08 +0800 Subject: [PATCH 17/20] Update networks.py --- python/paddle/trainer_config_helpers/networks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/trainer_config_helpers/networks.py b/python/paddle/trainer_config_helpers/networks.py index 28a71cf788f2b..34be203ee2545 100644 --- a/python/paddle/trainer_config_helpers/networks.py +++ b/python/paddle/trainer_config_helpers/networks.py @@ -1406,7 +1406,7 @@ def inputs(layers, *args): if len(args) != 0: layers.extend(args) - Inputs(* [l.name for l in layers]) + Inputs(*[l.name for l in layers]) def outputs(layers, *args): @@ -1456,7 +1456,7 @@ def __dfs_travel__(layer, assert len(layers) > 0 if HasInputsSet(): # input already set - Outputs(* [l.name for l in layers]) + Outputs(*[l.name for l in layers]) return # just return outputs. if len(layers) != 1: From 2ae37a4ea2f4b02ffe6b773590ed05c77675e6f5 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Thu, 31 Aug 2017 00:28:01 +0800 Subject: [PATCH 18/20] fix data_layer for 3D data --- python/paddle/trainer_config_helpers/layers.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 42bf1c19d1097..2aa86850d12e8 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -926,16 +926,18 @@ def data_layer(name, size, height=None, width=None, depth=None, type=LayerType.DATA, name=name, size=size, + depth=depth, height=height, width=width, - depth=depth, **ExtraLayerAttribute.to_kwargs(layer_attr)) + if depth is None: + depth = 1 num_filters = None if height is not None and width is not None: - num_filters = size / (width * height) - assert num_filters * width * height == size, \ - "size=%s width=%s height=%s" % (size, width, height) + num_filters = size / (width * height * depth) + assert num_filters * width * height*depth == size, \ + "size=%s width=%s height=%s depth=%s" % (size, width, height, depth) return LayerOutput(name, LayerType.DATA, size=size, num_filters=num_filters) From 2e97045c2354ea8a6ae39ee17e93098a2ec930d4 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Thu, 31 Aug 2017 14:10:40 +0800 Subject: [PATCH 19/20] fix layers_test.py --- .../tests/configs/file_list.sh | 2 +- ...3d_test_config.py => test_conv3d_layer.py} | 44 +--------------- .../tests/configs/test_deconv3d_layer.py | 50 +++++++++++++++++++ .../tests/layers_test.py | 3 +- 4 files changed, 53 insertions(+), 46 deletions(-) rename python/paddle/trainer_config_helpers/tests/configs/{conv3d_deconv3d_test_config.py => test_conv3d_layer.py} (51%) create mode 100644 python/paddle/trainer_config_helpers/tests/configs/test_deconv3d_layer.py diff --git a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh index 1ca5c8a07ebb7..729e8e67c2431 100755 --- a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh +++ b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh @@ -9,6 +9,6 @@ test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_layer test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer test_kmax_seq_socre_layer test_seq_select_layers test_scale_shift_layer -test_seq_slice_layer) +test_seq_slice_layer test_conv3d_layer test_deconv3d_layer) export whole_configs=(test_split_datasource) diff --git a/python/paddle/trainer_config_helpers/tests/configs/conv3d_deconv3d_test_config.py b/python/paddle/trainer_config_helpers/tests/configs/test_conv3d_layer.py similarity index 51% rename from python/paddle/trainer_config_helpers/tests/configs/conv3d_deconv3d_test_config.py rename to python/paddle/trainer_config_helpers/tests/configs/test_conv3d_layer.py index 15f7c1d271fce..aa0a2c0d5fe19 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/conv3d_deconv3d_test_config.py +++ b/python/paddle/trainer_config_helpers/tests/configs/test_conv3d_layer.py @@ -14,18 +14,6 @@ padding_z = 1 groups = 1 -data1 = data_layer(name='data1', size=2016 * num_channels, height=48, width=42) - -img_conv_layer( - input=data1, - filter_size=filter_size, - num_channels=num_channels, - num_filters=16, - stride=stride, - padding=padding, - act=LinearActivation(), - bias_attr=False) - data = data_layer( name='data', size=12096 * num_channels, height=48, width=42, depth=6) # first @@ -58,34 +46,4 @@ trans=False, layer_type="conv3d", act=LinearActivation()) - -# first -deconv3d_1 = img_conv3d_layer( - input=data, - name='deconv3d_1', - num_filters=16, - num_channels=num_channels, - filter_size=filter_size, - stride=stride, - padding=padding, - groups=groups, - bias_attr=True, - shared_biases=True, - trans=False, - layer_type="deconv3d", - act=LinearActivation()) -# second -deconv3d_2 = img_conv3d_layer( - input=data, - name='deconv3d_2', - num_filters=16, - num_channels=num_channels, - filter_size=[filter_size, filter_size_y, filter_size_z], - stride=[stride, stride_y, stride_z], - padding=[padding, padding_y, padding_z], - groups=groups, - bias_attr=True, - shared_biases=True, - trans=False, - layer_type="deconv3d", - act=LinearActivation()) +outputs(conv3d_2) diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_deconv3d_layer.py b/python/paddle/trainer_config_helpers/tests/configs/test_deconv3d_layer.py new file mode 100644 index 0000000000000..a113279fc17b4 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/test_deconv3d_layer.py @@ -0,0 +1,50 @@ +from paddle.trainer_config_helpers import * + +settings(batch_size=1000, learning_rate=1e-5) + +num_channels = 3 +filter_size = 3 +filter_size_y = 3 +filter_size_z = 3 +stride = 2 +stride_y = 2 +stride_z = 2 +padding = 1 +padding_y = 1 +padding_z = 1 +groups = 1 + +data = data_layer( + name='data', size=12096 * num_channels, height=48, width=42, depth=6) + +# first +deconv3d_1 = img_conv3d_layer( + input=data, + name='deconv3d_1', + num_filters=16, + num_channels=num_channels, + filter_size=filter_size, + stride=stride, + padding=padding, + groups=groups, + bias_attr=True, + shared_biases=True, + trans=True, + layer_type="deconv3d", + act=LinearActivation()) +# second +deconv3d_2 = img_conv3d_layer( + input=data, + name='deconv3d_2', + num_filters=16, + num_channels=num_channels, + filter_size=[filter_size, filter_size_y, filter_size_z], + stride=[stride, stride_y, stride_z], + padding=[padding, padding_y, padding_z], + groups=groups, + bias_attr=True, + shared_biases=True, + trans=True, + layer_type="deconv3d", + act=LinearActivation()) +outputs(deconv3d_2) diff --git a/python/paddle/trainer_config_helpers/tests/layers_test.py b/python/paddle/trainer_config_helpers/tests/layers_test.py index 44d1c1c9b2833..b3dd8f8fc7847 100644 --- a/python/paddle/trainer_config_helpers/tests/layers_test.py +++ b/python/paddle/trainer_config_helpers/tests/layers_test.py @@ -16,6 +16,5 @@ if __name__ == '__main__': parse_config_and_serialize( - 'trainer_config_helpers/tests/configs/conv3d_deconv3d_test_config.py', - '') + 'trainer_config_helpers/tests/layers_test_config.py', '') # layers_test_config.py From a4e1e127f3aa5a64cc777deab31a410874fd7ff7 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Thu, 31 Aug 2017 16:33:01 +0800 Subject: [PATCH 20/20] Add test_conv3d_layer.protostr,test_deconv3d_layer.protostr --- .../protostr/test_conv3d_layer.protostr | 132 ++++++++++++++++++ .../protostr/test_deconv3d_layer.protostr | 132 ++++++++++++++++++ 2 files changed, 264 insertions(+) create mode 100644 python/paddle/trainer_config_helpers/tests/configs/protostr/test_conv3d_layer.protostr create mode 100644 python/paddle/trainer_config_helpers/tests/configs/protostr/test_deconv3d_layer.protostr diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_conv3d_layer.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_conv3d_layer.protostr new file mode 100644 index 0000000000000..9fe2bc29d3cd0 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_conv3d_layer.protostr @@ -0,0 +1,132 @@ +type: "nn" +layers { + name: "data" + type: "data" + size: 36288 + active_type: "" + height: 48 + width: 42 + depth: 6 +} +layers { + name: "conv3d_1" + type: "conv3d" + size: 24192 + active_type: "" + inputs { + input_layer_name: "data" + input_parameter_name: "_conv3d_1.w0" + conv_conf { + filter_size: 3 + channels: 3 + stride: 2 + padding: 1 + groups: 1 + filter_channels: 3 + output_x: 21 + img_size: 42 + caffe_mode: true + filter_size_y: 3 + padding_y: 1 + stride_y: 2 + output_y: 24 + img_size_y: 48 + filter_size_z: 3 + padding_z: 1 + stride_z: 2 + output_z: 3 + img_size_z: 6 + } + } + bias_parameter_name: "_conv3d_1.wbias" + num_filters: 16 + shared_biases: true + height: 24 + width: 21 + depth: 3 +} +layers { + name: "conv3d_2" + type: "conv3d" + size: 24192 + active_type: "" + inputs { + input_layer_name: "data" + input_parameter_name: "_conv3d_2.w0" + conv_conf { + filter_size: 3 + channels: 3 + stride: 2 + padding: 1 + groups: 1 + filter_channels: 3 + output_x: 21 + img_size: 42 + caffe_mode: true + filter_size_y: 3 + padding_y: 1 + stride_y: 2 + output_y: 24 + img_size_y: 48 + filter_size_z: 3 + padding_z: 1 + stride_z: 2 + output_z: 3 + img_size_z: 6 + } + } + bias_parameter_name: "_conv3d_2.wbias" + num_filters: 16 + shared_biases: true + height: 24 + width: 21 + depth: 3 +} +parameters { + name: "_conv3d_1.w0" + size: 1296 + initial_mean: 0.0 + initial_std: 0.272165526976 + initial_strategy: 0 + initial_smart: false +} +parameters { + name: "_conv3d_1.wbias" + size: 16 + initial_mean: 0.0 + initial_std: 0.0 + dims: 16 + dims: 1 + initial_strategy: 0 + initial_smart: false +} +parameters { + name: "_conv3d_2.w0" + size: 1296 + initial_mean: 0.0 + initial_std: 0.272165526976 + initial_strategy: 0 + initial_smart: false +} +parameters { + name: "_conv3d_2.wbias" + size: 16 + initial_mean: 0.0 + initial_std: 0.0 + dims: 16 + dims: 1 + initial_strategy: 0 + initial_smart: false +} +input_layer_names: "data" +output_layer_names: "conv3d_2" +sub_models { + name: "root" + layer_names: "data" + layer_names: "conv3d_1" + layer_names: "conv3d_2" + input_layer_names: "data" + output_layer_names: "conv3d_2" + is_recurrent_layer_group: false +} + diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_deconv3d_layer.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_deconv3d_layer.protostr new file mode 100644 index 0000000000000..7bf409731cbf8 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_deconv3d_layer.protostr @@ -0,0 +1,132 @@ +type: "nn" +layers { + name: "data" + type: "data" + size: 36288 + active_type: "" + height: 48 + width: 42 + depth: 6 +} +layers { + name: "deconv3d_1" + type: "deconv3d" + size: 1387760 + active_type: "" + inputs { + input_layer_name: "data" + input_parameter_name: "_deconv3d_1.w0" + conv_conf { + filter_size: 3 + channels: 3 + stride: 2 + padding: 1 + groups: 1 + filter_channels: 16 + output_x: 42 + img_size: 83 + caffe_mode: true + filter_size_y: 3 + padding_y: 1 + stride_y: 2 + output_y: 48 + img_size_y: 95 + filter_size_z: 3 + padding_z: 1 + stride_z: 2 + output_z: 6 + img_size_z: 11 + } + } + bias_parameter_name: "_deconv3d_1.wbias" + num_filters: 16 + shared_biases: true + height: 95 + width: 83 + depth: 11 +} +layers { + name: "deconv3d_2" + type: "deconv3d" + size: 1387760 + active_type: "" + inputs { + input_layer_name: "data" + input_parameter_name: "_deconv3d_2.w0" + conv_conf { + filter_size: 3 + channels: 3 + stride: 2 + padding: 1 + groups: 1 + filter_channels: 16 + output_x: 42 + img_size: 83 + caffe_mode: true + filter_size_y: 3 + padding_y: 1 + stride_y: 2 + output_y: 48 + img_size_y: 95 + filter_size_z: 3 + padding_z: 1 + stride_z: 2 + output_z: 6 + img_size_z: 11 + } + } + bias_parameter_name: "_deconv3d_2.wbias" + num_filters: 16 + shared_biases: true + height: 95 + width: 83 + depth: 11 +} +parameters { + name: "_deconv3d_1.w0" + size: 6912 + initial_mean: 0.0 + initial_std: 0.272165526976 + initial_strategy: 0 + initial_smart: false +} +parameters { + name: "_deconv3d_1.wbias" + size: 16 + initial_mean: 0.0 + initial_std: 0.0 + dims: 16 + dims: 1 + initial_strategy: 0 + initial_smart: false +} +parameters { + name: "_deconv3d_2.w0" + size: 6912 + initial_mean: 0.0 + initial_std: 0.272165526976 + initial_strategy: 0 + initial_smart: false +} +parameters { + name: "_deconv3d_2.wbias" + size: 16 + initial_mean: 0.0 + initial_std: 0.0 + dims: 16 + dims: 1 + initial_strategy: 0 + initial_smart: false +} +input_layer_names: "data" +output_layer_names: "deconv3d_2" +sub_models { + name: "root" + layer_names: "data" + layer_names: "deconv3d_1" + layer_names: "deconv3d_2" + input_layer_names: "data" + output_layer_names: "deconv3d_2" + is_recurrent_layer_group: false +} +