From 4248b1fbd9cee672ac3fea9237844ca1136922b4 Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Thu, 2 Jun 2016 10:50:43 -0700 Subject: [PATCH] Use cuDNN routine FindEx to find best algorithm. FindEx is more stable than Get (heuristic-based) because it runs all the available algorithms and sorts them according to their speed. Both Get and FindEx are supported now and can be specified through the definition of each layer in prototxt. In Reshape, check whether shape of bottom and convolution descriptors have changed. In caffe time, do multiple (instead of one) fwd/bwd pass in the initilization phase. This is crucial because FindEx is executed in the first iterations and it takes quite a long time. --- include/caffe/layers/cudnn_conv_layer.hpp | 21 +- src/caffe/layers/cudnn_conv_layer.cpp | 363 +++++++++++++++++++--- src/caffe/layers/cudnn_conv_layer.cu | 11 +- src/caffe/proto/caffe.proto | 6 + tools/caffe.cpp | 33 +- 5 files changed, 377 insertions(+), 57 deletions(-) diff --git a/include/caffe/layers/cudnn_conv_layer.hpp b/include/caffe/layers/cudnn_conv_layer.hpp index 270595483ea..51ec6a38de0 100644 --- a/include/caffe/layers/cudnn_conv_layer.hpp +++ b/include/caffe/layers/cudnn_conv_layer.hpp @@ -34,7 +34,7 @@ class CuDNNConvolutionLayer : public ConvolutionLayer { public: explicit CuDNNConvolutionLayer(const LayerParameter& param) : ConvolutionLayer(param), handles_setup_(false), - backward_passed_ctr_(0) {} + use_algo_seeker_(true), use_modest_workspace_(true) {} virtual void LayerSetUp(const vector*>& bottom, const vector*>& top); virtual void Reshape(const vector*>& bottom, @@ -65,7 +65,24 @@ class CuDNNConvolutionLayer : public ConvolutionLayer { size_t *workspace_bwd_data_sizes_; size_t *workspace_bwd_filter_sizes_; GPUMemory::Workspace workspace; - int backward_passed_ctr_; + + private: + bool use_algo_seeker_; + bool use_modest_workspace_; + void FindExConvAlgo(const vector*>& bottom, + const vector*>& top, + const size_t workspace_bytes); + void GetConvAlgo(const vector*>& bottom, + const vector*>& top, + const size_t workspace_bytes); + + vector cached_bottom_descs_; + vector cached_conv_descs_; + bool IsBottomDescChanged(const vector*>& bottom); + bool IsConvDescChanged(const vector*>& bottom); + + bool use_reshape_; + bool initialized_cached_descs_; }; #endif diff --git a/src/caffe/layers/cudnn_conv_layer.cpp b/src/caffe/layers/cudnn_conv_layer.cpp index 061f3998c00..6f149fc6237 100644 --- a/src/caffe/layers/cudnn_conv_layer.cpp +++ b/src/caffe/layers/cudnn_conv_layer.cpp @@ -28,12 +28,13 @@ void CuDNNConvolutionLayer::LayerSetUp( workspace_bwd_filter_sizes_ = new size_t[bottom.size()]; workspace_bwd_data_sizes_ = new size_t[bottom.size()]; + // Initializing algorithms and workspaces + // Do not rely on initialized algorithms (Reshape will set algorithms + // with correct values in the first iteration). for (size_t i = 0; i < bottom.size(); ++i) { - // initialize all to default algorithms fwd_algo_[i] = (cudnnConvolutionFwdAlgo_t)0; bwd_filter_algo_[i] = (cudnnConvolutionBwdFilterAlgo_t)0; bwd_data_algo_[i] = (cudnnConvolutionBwdDataAlgo_t)0; - // default algorithms don't require workspace workspace_fwd_sizes_[i] = 0; workspace_bwd_data_sizes_[i] = 0; workspace_bwd_filter_sizes_[i] = 0; @@ -57,12 +58,22 @@ void CuDNNConvolutionLayer::LayerSetUp( cudnnTensorDescriptor_t bottom_desc; cudnn::createTensor4dDesc(&bottom_desc); bottom_descs_.push_back(bottom_desc); + cudnnTensorDescriptor_t top_desc; cudnn::createTensor4dDesc(&top_desc); top_descs_.push_back(top_desc); + cudnnConvolutionDescriptor_t conv_desc; cudnn::createConvolutionDesc(&conv_desc); conv_descs_.push_back(conv_desc); + + cudnnTensorDescriptor_t cached_bottom_desc; + cudnn::createTensor4dDesc(&cached_bottom_desc); + cached_bottom_descs_.push_back(cached_bottom_desc); + + cudnnConvolutionDescriptor_t cached_conv_desc; + cudnn::createConvolutionDesc(&cached_conv_desc); + cached_conv_descs_.push_back(cached_conv_desc); } // Tensor descriptor for bias. @@ -71,17 +82,51 @@ void CuDNNConvolutionLayer::LayerSetUp( } handles_setup_ = true; - backward_passed_ctr_ = 0; + // When true, Reshape asks cuDNN (either Get ot FindEx) for the best algorithm + use_algo_seeker_ = true; + // When true, a small amount of workspace is allowed for algorithms + use_modest_workspace_ = true; + // When true, Reshape sets descriptors, algorithms, workspaces. + use_reshape_ = true; + // When true, cached bottom and conv descriptors need to be set. + initialized_cached_descs_ = false; } template void CuDNNConvolutionLayer::Reshape( const vector*>& bottom, const vector*>& top) { + // Check whether cached descriptors have been initialized. + if (initialized_cached_descs_) { + // Check whether bottom and conv descriptors have changed, + // which then requires a new reshape and set algo. + if ((IsBottomDescChanged(bottom)) || + (IsConvDescChanged(bottom))) { + use_reshape_ = true; + // When reshape, algorithms need to be set again. + use_algo_seeker_ = true; + use_modest_workspace_ = true; + } else { + // When no reshape is needed, setting algo may be still needed + // (for example, if we are at iteration 1). + // If we want to set algos, we have to use reshape in + // current implementation. + use_reshape_ = use_algo_seeker_; + } + } else { + // If cached descriptors are not initialized yet, need to + // do reshape which also initializes cached descriptors. + use_reshape_ = true; + } + if (!use_reshape_) { + return; + } + ConvolutionLayer::Reshape(bottom, top); CHECK_EQ(2, this->num_spatial_axes_) << "CuDNNConvolution input must have 2 spatial axes " << "(e.g., height and width). " << "Use 'engine: CAFFE' for general ND convolution."; + bottom_offset_ = this->bottom_dim_ / this->group_; top_offset_ = this->top_dim_ / this->group_; const int height = bottom[0]->shape(this->channel_axis_ + 1); @@ -95,11 +140,7 @@ void CuDNNConvolutionLayer::Reshape( const int stride_h = stride_data[0]; const int stride_w = stride_data[1]; - // Specify workspace limit for kernels directly until we have a - // planning strategy and a rewrite of Caffe's GPU memory mangagement - size_t workspace_limit_bytes, total_memory; - GPUMemory::GetInfo(&workspace_limit_bytes, &total_memory); - + // Set cuDNN tensor and convolution descriptors for (int i = 0; i < bottom.size(); i++) { cudnn::setTensor4dDesc(&bottom_descs_[i], this->num_, @@ -111,48 +152,77 @@ void CuDNNConvolutionLayer::Reshape( this->num_output_ / this->group_, height_out, width_out, this->num_output_ * this->out_spatial_dim_, this->out_spatial_dim_, width_out, 1); - cudnn::setConvolutionDesc(&conv_descs_[i], bottom_descs_[i], filter_desc_, pad_h, pad_w, stride_h, stride_w); + // Set cached descriptors + cudnn::setTensor4dDesc(&cached_bottom_descs_[i], + this->num_, + this->channels_ / this->group_, height, width, + this->channels_ * height * width, + height * width, width, 1); + cudnn::setConvolutionDesc(&cached_conv_descs_[i], + cached_bottom_descs_[i], + filter_desc_, pad_h, pad_w, stride_h, stride_w); + } + initialized_cached_descs_ = true; - // Have to pass full fwd/bwd cycle before taking the rest of memory - if (backward_passed_ctr_ > 1) { - // choose forward and backward algorithms + workspace(s) - CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(Caffe::cudnn_handle(), - bottom_descs_[i], filter_desc_, conv_descs_[i], top_descs_[i], - CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, - workspace_limit_bytes, &fwd_algo_[i])); + // Ask cuDNN to find the best algorithm + if (use_algo_seeker_) { + size_t workspace_limit_bytes, total_memory; + GPUMemory::GetInfo(&workspace_limit_bytes, &total_memory); + // FindEx: A workspace of size workspace_bytes is allocated for FindEx. + // Besides, workspace, a buffer is allocated for the output of + // FindEx-backward-filter. The size of buffer is as big as weights. + // Get: workspace_bytes is only used as a workspace limit by Get. + // (no allocation happens before Get or by Get). + size_t workspace_bytes; + if (use_modest_workspace_) { + // In iteration 0, use a small amount of memory in order to leave + // most of memory for allocating layer blobs. + // TODO: Read 8*1024*1024 from a data member variable. + workspace_bytes = 8*1024*1024; + } else { + // Use 90% of available memory. + // Using all of memory may result in failure of workspace.reserve. + // TODO: Since 90% of memory might be too large, we can allocate + // exactly how much FindEx needs by taking the maximum + // workspace among all algorithms (requires an initial call + // to FindEx with workspace size 0). + // TODO: Read 0.9 from a data member variable. + workspace_bytes = workspace_limit_bytes * 0.9; + // Avoid seeking for an algorithm in subsequent iterations + use_algo_seeker_ = false; } + switch (this->layer_param_.convolution_param(). + cudnn_convolution_algo_seeker()) { + case ConvolutionParameter_CuDNNConvolutionAlgorithmSeeker_GET: + this->GetConvAlgo(bottom, top, workspace_bytes); + break; + case ConvolutionParameter_CuDNNConvolutionAlgorithmSeeker_FINDEX: + this->FindExConvAlgo(bottom, top, workspace_bytes); + break; + default: + LOG(ERROR) << "Wrong value for cudnn_convolution_algo_seeker"; + return; + } + } + // At this point, the algorithms and their workspace are set. + // Still need to query cuDNN for workspace size to check whether the + // selected algorithms are valid because: + // FindEx may return success while giving no valid algorithm as there + // may be no algorithm available for given parameters. + for (int i = 0; i < bottom.size(); i++) { + // forward algorithm CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(Caffe::cudnn_handle(), bottom_descs_[i], filter_desc_, conv_descs_[i], top_descs_[i], fwd_algo_[i], &(workspace_fwd_sizes_[i]))); - - if (backward_passed_ctr_ > 1) { - // choose backward algorithm for filter - CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm( - Caffe::cudnn_handle(), - bottom_descs_[i], top_descs_[i], conv_descs_[i], filter_desc_, - CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, - workspace_limit_bytes, &bwd_filter_algo_[i])); - } - - // get workspace for backwards filter algorithm + // backward filter algorithm CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize( Caffe::cudnn_handle(), bottom_descs_[i], top_descs_[i], conv_descs_[i], filter_desc_, bwd_filter_algo_[i], &workspace_bwd_filter_sizes_[i])); - - if (backward_passed_ctr_ > 1) { - // choose backward algo for data - CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm( - Caffe::cudnn_handle(), - filter_desc_, top_descs_[i], conv_descs_[i], bottom_descs_[i], - CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, - workspace_limit_bytes, &bwd_data_algo_[i])); - } - - // get workspace size + // backward data algorithm CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize( Caffe::cudnn_handle(), filter_desc_, top_descs_[i], conv_descs_[i], bottom_descs_[i], @@ -166,20 +236,227 @@ void CuDNNConvolutionLayer::Reshape( } } +template +void CuDNNConvolutionLayer::GetConvAlgo( + const vector*>& bottom, + const vector*>& top, + const size_t workspace_bytes) { + + for (int i = 0; i < bottom.size(); i++) { + // Get forward algorithm + CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(Caffe::cudnn_handle(), + bottom_descs_[i], filter_desc_, conv_descs_[i], top_descs_[i], + CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, + workspace_bytes, &fwd_algo_[i])); + // Get backward filter algorithm + CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm( + Caffe::cudnn_handle(), + bottom_descs_[i], top_descs_[i], conv_descs_[i], filter_desc_, + CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, + workspace_bytes, &bwd_filter_algo_[i])); + // Get backward data algorithm + CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm( + Caffe::cudnn_handle(), + filter_desc_, top_descs_[i], conv_descs_[i], bottom_descs_[i], + CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, + workspace_bytes, &bwd_data_algo_[i])); + } +} + +template +void CuDNNConvolutionLayer::FindExConvAlgo( + const vector*>& bottom, + const vector*>& top, + const size_t workspace_bytes) { + + // Number of algorithms we want to consider + // Since we only consider one algorithm (the fastest), set this to 1 + const int kRequestAlgoCount = 1; + int fwd_algo_count; + int filter_algo_count; + int data_algo_count; + + cudnnConvolutionFwdAlgoPerf_t fwd_results[kRequestAlgoCount]; + cudnnConvolutionBwdFilterAlgoPerf_t bwd_filter_results[kRequestAlgoCount]; + cudnnConvolutionBwdDataAlgoPerf_t bwd_data_results[kRequestAlgoCount]; + + // Allocate temporary buffer for weights used for backward filter FindEx + void *tmp_weights; + const int tmp_weights_size = sizeof(Dtype) * weight_offset_; + GPUMemory::allocate(&tmp_weights, tmp_weights_size); + + // TODO: Try reducing workspace_bytes if it fails. + // In case, workspace_bytes is 90% of available memory, + // reduce it to 75%; if it fails again, reduce it to 50% and so on. + workspace.reserve(workspace_bytes); + + for (int i = 0; i < bottom.size(); i++) { + // Find forward algorithm + CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithmEx( + Caffe::cudnn_handle(), + bottom_descs_[i], + bottom[i]->gpu_data(), + filter_desc_, + this->blobs_[0]->gpu_data(), + conv_descs_[i], + top_descs_[i], + top[i]->mutable_gpu_data(), + kRequestAlgoCount, + &fwd_algo_count, + fwd_results, + workspace.data(), + workspace.size())); + fwd_algo_[i] = fwd_results[0].algo; + workspace_fwd_sizes_[i] = fwd_results[0].memory; + + // Find backward filter algorithm + CUDNN_CHECK(cudnnFindConvolutionBackwardFilterAlgorithmEx( + Caffe::cudnn_handle(), + bottom_descs_[i], + bottom[i]->gpu_data(), + top_descs_[i], + top[i]->gpu_diff(), + conv_descs_[i], + filter_desc_, + tmp_weights, + kRequestAlgoCount, + &filter_algo_count, + bwd_filter_results, + workspace.data(), + workspace.size())); + bwd_filter_algo_[i] = bwd_filter_results[0].algo; + workspace_bwd_filter_sizes_[i] = bwd_filter_results[0].memory; + + // Find backward data algorithm + CUDNN_CHECK(cudnnFindConvolutionBackwardDataAlgorithmEx( + Caffe::cudnn_handle(), + filter_desc_, + this->blobs_[0]->gpu_data(), + top_descs_[i], + top[i]->gpu_diff(), + conv_descs_[i], + bottom_descs_[i], + bottom[i]->mutable_gpu_diff(), + kRequestAlgoCount, + &data_algo_count, + bwd_data_results, + workspace.data(), + workspace.size())); + + bwd_data_algo_[i] = bwd_data_results[0].algo; + workspace_bwd_data_sizes_[i] = bwd_data_results[0].memory; + } + GPUMemory::deallocate(tmp_weights); + workspace.release(); +} + +// Checked if there is a difference between the corresponding descriptors in +// cached_bottom_descs_ and bottom_descs_. +// No need to compare all parameters: batchsize, height, and width are enough. +template +bool CuDNNConvolutionLayer::IsBottomDescChanged( + const vector*>& bottom) { + int cached_n; int cached_c; int cached_h; int cached_w; + int cached_stride_n; int cached_stride_c; + int cached_stride_h; int cached_stride_w; + int n; int c; int h; int w; + int stride_n; int stride_c; + int stride_h; int stride_w; + cudnnDataType_t type; + + for (int i = 0; i < bottom.size(); i++) { + CUDNN_CHECK(cudnnGetTensor4dDescriptor( + cached_bottom_descs_[i], + &type, + &cached_n, &cached_c, &cached_h, &cached_w, + &cached_stride_n, &cached_stride_c, + &cached_stride_h, &cached_stride_w)); + CUDNN_CHECK(cudnnGetTensor4dDescriptor( + bottom_descs_[i], + &type, + &n, &c, &h, &w, + &stride_n, &stride_c, + &stride_h, &stride_w)); + + if ((cached_n != n) || + (cached_c != c) || + (cached_h != h) || + (cached_w != w) || + (cached_stride_n != stride_n) || + (cached_stride_c != stride_c) || + (cached_stride_h != stride_h) || + (cached_stride_w != stride_w)) { + return true; + } + } + return false; +} + + +// Checked if there is a difference between the corresponding descriptors in +// cached_conv_descs_ and conv_descs_. +// No need to compare all parameters; pads, strides, and upscales are enough. +template +bool CuDNNConvolutionLayer::IsConvDescChanged( + const vector*>& bottom) { + int cached_padA[2]; + int padA[2]; + int cached_strideA[2]; + int strideA[2]; + int cached_upscaleA[2]; + int upscaleA[2]; + int arrayLength; + cudnnConvolutionMode_t mode; + cudnnDataType_t type; + + for (int i = 0; i < bottom.size(); i++) { + CUDNN_CHECK(cudnnGetConvolutionNdDescriptor( + cached_conv_descs_[i], + 2, + &arrayLength, + cached_padA, + cached_strideA, + cached_upscaleA, + &mode, + &type)); + CUDNN_CHECK(cudnnGetConvolutionNdDescriptor( + conv_descs_[i], + 2, + &arrayLength, + padA, + strideA, + upscaleA, + &mode, + &type)); + + if ((cached_padA[0] != padA[0]) || + (cached_padA[1] != padA[1]) || + (cached_strideA[0] != strideA[0]) || + (cached_strideA[1] != strideA[1]) || + (cached_upscaleA[0] != upscaleA[0]) || + (cached_upscaleA[1] != upscaleA[1])) { + return true; + } + } + return false; +} + template CuDNNConvolutionLayer::~CuDNNConvolutionLayer() { // Check that handles have been setup before destroying. if (!handles_setup_) { return; } for (int i = 0; i < bottom_descs_.size(); i++) { - cudnnDestroyTensorDescriptor(bottom_descs_[i]); - cudnnDestroyTensorDescriptor(top_descs_[i]); - cudnnDestroyConvolutionDescriptor(conv_descs_[i]); + CUDNN_CHECK(cudnnDestroyTensorDescriptor(bottom_descs_[i])); + CUDNN_CHECK(cudnnDestroyTensorDescriptor(top_descs_[i])); + CUDNN_CHECK(cudnnDestroyConvolutionDescriptor(conv_descs_[i])); + CUDNN_CHECK(cudnnDestroyTensorDescriptor(cached_bottom_descs_[i])); + CUDNN_CHECK(cudnnDestroyConvolutionDescriptor(cached_conv_descs_[i])); } if (this->bias_term_) { - cudnnDestroyTensorDescriptor(bias_desc_); + CUDNN_CHECK(cudnnDestroyTensorDescriptor(bias_desc_)); } - cudnnDestroyFilterDescriptor(filter_desc_); + CUDNN_CHECK(cudnnDestroyFilterDescriptor(filter_desc_)); delete [] fwd_algo_; delete [] bwd_filter_algo_; diff --git a/src/caffe/layers/cudnn_conv_layer.cu b/src/caffe/layers/cudnn_conv_layer.cu index f7fe33a3082..005eb8236b4 100644 --- a/src/caffe/layers/cudnn_conv_layer.cu +++ b/src/caffe/layers/cudnn_conv_layer.cu @@ -24,12 +24,14 @@ void CuDNNConvolutionLayer::Forward_gpu( size_t workspace_limit_bytes, total_memory; GPUMemory::GetInfo(&workspace_limit_bytes, &total_memory); if (workspace_fwd_sizes_[i] > workspace_limit_bytes) { + use_algo_seeker_ = true; this->Reshape(bottom, top); } // Sometimes closer to zero we might have memory info diverged from reality // If try_reserve fails, it updates the info internally and we proceed with // Reshape one more time if (!workspace.try_reserve(workspace_fwd_sizes_[i])) { + use_algo_seeker_ = true; this->Reshape(bottom, top); workspace.reserve(workspace_fwd_sizes_[i]); } @@ -63,6 +65,8 @@ void CuDNNConvolutionLayer::Forward_gpu( // NOLINT_NEXT_LINE(whitespace/operators) CUDA_CHECK(cudaStreamSynchronize(cudaStreamLegacy)); } + // Possibly use faster algorithms by allowing larger workspace. + use_modest_workspace_ = false; } template @@ -84,9 +88,8 @@ void CuDNNConvolutionLayer::Backward_gpu(const vector*>& top, size_t workspace_limit_bytes, total_memory; GPUMemory::GetInfo(&workspace_limit_bytes, &total_memory); if (workspace_bwd_filter_sizes_[i] > workspace_limit_bytes || - workspace_bwd_data_sizes_[i] > workspace_limit_bytes || - // We need to get workspace sizes for the default algos at 1st run - backward_passed_ctr_ == 0) { + workspace_bwd_data_sizes_[i] > workspace_limit_bytes) { + use_algo_seeker_ = true; this->Reshape(bottom, top); } // To remove pressure on allocator, allocate the larger of the @@ -96,6 +99,7 @@ void CuDNNConvolutionLayer::Backward_gpu(const vector*>& top, // Reshape one more time if (!workspace.try_reserve(std::max(workspace_bwd_filter_sizes_[i], workspace_bwd_data_sizes_[i]))) { + use_algo_seeker_ = true; this->Reshape(bottom, top); workspace.reserve(std::max(workspace_bwd_filter_sizes_[i], workspace_bwd_data_sizes_[i])); @@ -146,7 +150,6 @@ void CuDNNConvolutionLayer::Backward_gpu(const vector*>& top, // NOLINT_NEXT_LINE(whitespace/operators) CUDA_CHECK(cudaStreamSynchronize(cudaStreamLegacy)); } - ++backward_passed_ctr_; } INSTANTIATE_LAYER_GPU_FUNCS(CuDNNConvolutionLayer); diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index bfc3669abaa..2ec2557b31e 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -690,6 +690,12 @@ message ConvolutionParameter { // implementation; for input blobs with num_axes != 2, this option is // ignored and the ND implementation will be used.) optional bool force_nd_im2col = 17 [default = false]; + enum CuDNNConvolutionAlgorithmSeeker { + GET = 0; + FINDEX = 1; + } + //Specifies which cudnn routine should be used to find the best convolution algorithm + optional CuDNNConvolutionAlgorithmSeeker cudnn_convolution_algo_seeker = 19 [default = FINDEX]; } message CropParameter { diff --git a/tools/caffe.cpp b/tools/caffe.cpp index 2ef774313cc..e8f0733771a 100644 --- a/tools/caffe.cpp +++ b/tools/caffe.cpp @@ -354,16 +354,35 @@ int time() { // Instantiate the caffe net. Net caffe_net(FLAGS_model, caffe::TRAIN); - // Do a clean forward and backward pass, so that memory allocation are done + // Do a number of clean forward and backward pass, + // so that memory allocation are done, // and future iterations will be more stable. - LOG(INFO) << "Performing Forward"; + Timer forward_timer; + Timer backward_timer; + double forward_time = 0.0; + double backward_time = 0.0; + const int kInitIterations = 5; + LOG(INFO) << "Initialization for " << kInitIterations << " iterations."; // Note that for the speed benchmark, we will assume that the network does // not take any input blobs. + LOG(INFO) << "Performing Forward"; float initial_loss; - caffe_net.Forward(&initial_loss); + forward_timer.Start(); + for (int j = 0; j < kInitIterations; ++j) { + caffe_net.Forward(&initial_loss); + } + forward_time += forward_timer.MicroSeconds(); LOG(INFO) << "Initial loss: " << initial_loss; LOG(INFO) << "Performing Backward"; - caffe_net.Backward(); + backward_timer.Start(); + for (int j = 0; j < kInitIterations; ++j) { + caffe_net.Backward(); + } + backward_time += backward_timer.MicroSeconds(); + LOG(INFO) << "Average Initialization Forward pass: " << forward_time / + 1000 / kInitIterations << " ms."; + LOG(INFO) << "Average Initialization Backward pass: " << backward_time / + 1000 / kInitIterations << " ms."; const vector > >& layers = caffe_net.layers(); const vector*> >& bottom_vecs = caffe_net.bottom_vecs(); @@ -374,13 +393,11 @@ int time() { LOG(INFO) << "Testing for " << FLAGS_iterations << " iterations."; Timer total_timer; total_timer.Start(); - Timer forward_timer; - Timer backward_timer; Timer timer; std::vector forward_time_per_layer(layers.size(), 0.0); std::vector backward_time_per_layer(layers.size(), 0.0); - double forward_time = 0.0; - double backward_time = 0.0; + forward_time = 0.0; + backward_time = 0.0; for (int j = 0; j < FLAGS_iterations; ++j) { Timer iter_timer; iter_timer.Start();