diff --git a/paddle/gserver/layers/CudnnConvLayer.cpp b/paddle/gserver/layers/CudnnConvLayer.cpp index a74e6ba38dfc6..0f932f960f6ba 100644 --- a/paddle/gserver/layers/CudnnConvLayer.cpp +++ b/paddle/gserver/layers/CudnnConvLayer.cpp @@ -85,6 +85,7 @@ bool CudnnConvLayer::init(const LayerMap &layerMap, biasOffset_ = numFilters_ / groups_[0]; } + batchNum_ = 0; isSelectAlgo_ = false; return true; } @@ -132,6 +133,11 @@ void CudnnConvLayer::reshape(int batchSize) { getOutput().setFrameHeight(outputH_); getOutput().setFrameWidth(outputW_); + // if the batchSize remains the same, set isSelectAlgo_ true. + // Otherwise, set isSelectAlgo_ false and select algo again. + isSelectAlgo_ = (batchSize == batchNum_); + batchNum_ = batchSize; + size_t maxWorkSpace = 0; for (size_t i = 0; i < inputLayers_.size(); i++) { CHECK_EQ(inputLayers_[i]->getOutput().value->getWidth(), @@ -160,6 +166,10 @@ void CudnnConvLayer::reshape(int batchSize) { maxWorkSpace = std::max(fwdLimitBytes_[i], bwdDataLimitBytes_[i]); maxWorkSpace = std::max(maxWorkSpace, bwdFilterLimitBytes_[i]); + + VLOG(3) << getName() << " Fwd / BwdData / BwdFilter algo: " << fwdAlgo_[i] + << " / " << bwdDataAlgo_[i] + << " / " << bwdFilterAlgo_[i]; } } diff --git a/paddle/gserver/layers/CudnnConvLayer.h b/paddle/gserver/layers/CudnnConvLayer.h index 2c72ba885ed10..a6dadba10daa4 100644 --- a/paddle/gserver/layers/CudnnConvLayer.h +++ b/paddle/gserver/layers/CudnnConvLayer.h @@ -87,6 +87,10 @@ class CudnnConvLayer : public ConvBaseLayer { /// Is or not select conv algorihtm. bool isSelectAlgo_; + /// batchNum is used to record batch size. If the batch size is changed, + /// the selection algorithm will be called. + int batchNum_; + public: explicit CudnnConvLayer(const LayerConfig& config) : ConvBaseLayer(config) {}