Merge pull request BVLC#2009 from philkr/memory
Saving memory by reusing col_buffer_.
coltrane89 committed Apr 8, 2015
2 parents 7e0fb7c + f8b9cfc commit 57eb3e9
Showing 4 changed files with 253 additions and 3 deletions.
46 changes: 46 additions & 0 deletions include/caffe/tempmem.hpp
@@ -0,0 +1,46 @@
#ifndef CAFFE_TEMPMEM_HPP_
#define CAFFE_TEMPMEM_HPP_

#include <cstdlib>

#include "caffe/common.hpp"
#include "caffe/syncedmem.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {

/**
* @brief Holds a block of temporary memory that can be shared between
* different parts of caffe. The CPU and GPU memory is *not*
* synchronized.
*
* TODO(dox): more thorough description.
*/
template<typename Dtype>
class TemporaryMemory {
public:
TemporaryMemory():cpu_ptr_(NULL), gpu_ptr_(NULL), size_(0) {}
explicit TemporaryMemory(size_t size);
~TemporaryMemory();

void acquire_gpu();
void release_gpu();
void acquire_cpu();
void release_cpu();
const Dtype* cpu_data() const;
const Dtype* gpu_data() const;
Dtype* mutable_cpu_data();
Dtype* mutable_gpu_data();
size_t size() { return size_; }
void resize(size_t size);
private:
Dtype* cpu_ptr_;
Dtype* gpu_ptr_;
size_t size_;

DISABLE_COPY_AND_ASSIGN(TemporaryMemory);
}; // class TemporaryMemory

} // namespace caffe

#endif // CAFFE_TEMPMEM_HPP_
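
For orientation, here is a hypothetical usage sketch (not part of this commit) of the interface above. resize() only registers a capacity request with the shared pool; the data pointers are valid only between acquire and release, since the accessors CHECK for an acquired block:

// Hypothetical caller, assuming a float instantiation.
TemporaryMemory<float> scratch;
scratch.resize(4096);                     // register worst-case element count
scratch.acquire_cpu();                    // lock a block from the shared pool
float* data = scratch.mutable_cpu_data();
caffe_memset(4096 * sizeof(float), 0, data);  // use like ordinary memory
scratch.release_cpu();                    // hand the block back for reuse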
3 changes: 2 additions & 1 deletion include/caffe/vision_layers.hpp
@@ -13,6 +13,7 @@
#include "caffe/loss_layers.hpp"
#include "caffe/neuron_layers.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/tempmem.hpp"

namespace caffe {

@@ -107,7 +108,7 @@ class BaseConvolutionLayer : public Layer<Dtype> {
int col_offset_;
int output_offset_;

- Blob<Dtype> col_buffer_;
+ TemporaryMemory<Dtype> col_buffer_;
Blob<Dtype> bias_multiplier_;
};

16 changes: 14 additions & 2 deletions src/caffe/layers/base_conv_layer.cpp
@@ -142,9 +142,9 @@ void BaseConvolutionLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
// overly large memory usage. In the special case of 1x1 convolution
// it goes lazily unused to save memory.
if (reverse_dimensions()) {
- col_buffer_.Reshape(1, kernel_dim_, height_, width_);
+ col_buffer_.resize(kernel_dim_*height_*width_);
} else {
- col_buffer_.Reshape(1, kernel_dim_, height_out_, width_out_);
+ col_buffer_.resize(kernel_dim_*height_out_*width_out_);
}
// Set up the all ones "bias multiplier" for adding biases by BLAS
if (bias_term_) {
@@ -159,6 +159,7 @@ template <typename Dtype>
void BaseConvolutionLayer<Dtype>::forward_cpu_gemm(const Dtype* input,
const Dtype* weights, Dtype* output, bool skip_im2col) {
const Dtype* col_buff = input;
+ col_buffer_.acquire_cpu();
if (!is_1x1_) {
if (!skip_im2col) {
conv_im2col_cpu(input, col_buffer_.mutable_cpu_data());
@@ -171,6 +172,7 @@ void BaseConvolutionLayer<Dtype>::forward_cpu_gemm(const Dtype* input,
(Dtype)1., weights + weight_offset_ * g, col_buff + col_offset_ * g,
(Dtype)0., output + output_offset_ * g);
}
+ col_buffer_.release_cpu();
}

template <typename Dtype>
@@ -184,6 +186,7 @@ void BaseConvolutionLayer<Dtype>::forward_cpu_bias(Dtype* output,
template <typename Dtype>
void BaseConvolutionLayer<Dtype>::backward_cpu_gemm(const Dtype* output,
const Dtype* weights, Dtype* input) {
+ col_buffer_.acquire_cpu();
Dtype* col_buff = col_buffer_.mutable_cpu_data();
if (is_1x1_) {
col_buff = input;
@@ -197,11 +200,13 @@ void BaseConvolutionLayer<Dtype>::backward_cpu_gemm(const Dtype* output,
if (!is_1x1_) {
conv_col2im_cpu(col_buff, input);
}
+ col_buffer_.release_cpu();
}

template <typename Dtype>
void BaseConvolutionLayer<Dtype>::weight_cpu_gemm(const Dtype* input,
const Dtype* output, Dtype* weights) {
+ col_buffer_.acquire_cpu();
const Dtype* col_buff = input;
if (!is_1x1_) {
conv_im2col_cpu(input, col_buffer_.mutable_cpu_data());
@@ -213,6 +218,7 @@ void BaseConvolutionLayer<Dtype>::weight_cpu_gemm(const Dtype* input,
(Dtype)1., output + output_offset_ * g, col_buff + col_offset_ * g,
(Dtype)1., weights + weight_offset_ * g);
}
+ col_buffer_.release_cpu();
}

template <typename Dtype>
@@ -227,6 +233,7 @@ void BaseConvolutionLayer<Dtype>::backward_cpu_bias(Dtype* bias,
template <typename Dtype>
void BaseConvolutionLayer<Dtype>::forward_gpu_gemm(const Dtype* input,
const Dtype* weights, Dtype* output, bool skip_im2col) {
+ col_buffer_.acquire_gpu();
const Dtype* col_buff = input;
if (!is_1x1_) {
if (!skip_im2col) {
@@ -240,6 +247,7 @@ void BaseConvolutionLayer<Dtype>::forward_gpu_gemm(const Dtype* input,
(Dtype)1., weights + weight_offset_ * g, col_buff + col_offset_ * g,
(Dtype)0., output + output_offset_ * g);
}
+ col_buffer_.release_gpu();
}

template <typename Dtype>
@@ -253,6 +261,7 @@ void BaseConvolutionLayer<Dtype>::forward_gpu_bias(Dtype* output,
template <typename Dtype>
void BaseConvolutionLayer<Dtype>::backward_gpu_gemm(const Dtype* output,
const Dtype* weights, Dtype* input) {
+ col_buffer_.acquire_gpu();
Dtype* col_buff = col_buffer_.mutable_gpu_data();
if (is_1x1_) {
col_buff = input;
@@ -266,11 +275,13 @@ void BaseConvolutionLayer<Dtype>::backward_gpu_gemm(const Dtype* output,
if (!is_1x1_) {
conv_col2im_gpu(col_buff, input);
}
+ col_buffer_.release_gpu();
}

template <typename Dtype>
void BaseConvolutionLayer<Dtype>::weight_gpu_gemm(const Dtype* input,
const Dtype* output, Dtype* weights) {
+ col_buffer_.acquire_gpu();
const Dtype* col_buff = input;
if (!is_1x1_) {
conv_im2col_gpu(input, col_buffer_.mutable_gpu_data());
@@ -282,6 +293,7 @@ void BaseConvolutionLayer<Dtype>::weight_gpu_gemm(const Dtype* input,
(Dtype)1., output + output_offset_ * g, col_buff + col_offset_ * g,
(Dtype)1., weights + weight_offset_ * g);
}
+ col_buffer_.release_gpu();
}

template <typename Dtype>
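
To make the saving concrete (illustrative numbers, not from the commit): the col buffer holds roughly kernel_dim_ * height_out_ * width_out_ elements, where kernel_dim_ is the input channel count times the kernel height and width. A 3x3 convolution over 256 input channels producing a 64x64 map therefore needs 256 * 9 * 64 * 64 = 9,437,184 floats, about 36 MB in single precision. Before this change each convolution layer owned a Blob of that size for the lifetime of the net; with the acquire/release pattern above, all layers draw from one shared pool, so the cost falls from the sum over layers to the maximum over concurrently acquired buffers.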
191 changes: 191 additions & 0 deletions src/caffe/tempmem.cpp
@@ -0,0 +1,191 @@
#include <boost/shared_ptr.hpp>
#include <boost/smart_ptr/make_shared.hpp>
#include <boost/thread/locks.hpp>
#include <boost/thread/mutex.hpp>
#include <cstring>
#include <vector>

#include "caffe/common.hpp"
#include "caffe/tempmem.hpp"
#include "caffe/util/math_functions.hpp"
namespace caffe {

template<bool gpu>
struct TemporaryMemoryAllocator {
};
#ifndef CPU_ONLY
// GPU allocator
template<>
struct TemporaryMemoryAllocator<true> {
static void * calloc(size_t size) {
void * p;
CUDA_CHECK(cudaMalloc(&p, size));
caffe_gpu_memset(size, 0, p);
return p;
}
static void free(void * p) {
cudaFree(p);
}
};
#endif
// CPU allocator
template<>
struct TemporaryMemoryAllocator<false> {
static void * calloc(size_t size) {
void * p;
CaffeMallocHost(&p, size);
caffe_memset(size, 0, p);
return p;
}
static void free(void * p) {
CaffeFreeHost(p);
}
};

template<bool gpu>
class GlobalTemporaryMemory {
private:
class Block{
private:
Block(const Block& o):data_(NULL), size_(0), is_locked_(false) {}

public:
void * data_;
size_t size_;
bool is_locked_;
Block():data_(NULL), size_(0), is_locked_(false) {}
~Block() {
if (data_)
TemporaryMemoryAllocator<gpu>::free(data_);
}
void * try_lock(size_t max_size) {
if (is_locked_) return NULL;
is_locked_ = true;
if (size_ < max_size) {
size_ = max_size;
if (data_)
TemporaryMemoryAllocator<gpu>::free(data_);
data_ = TemporaryMemoryAllocator<gpu>::calloc(size_);
}
return data_;
}
void unlock() {
is_locked_ = false;
}
};
std::vector< boost::shared_ptr<Block> > blocks_;
size_t max_size_;
boost::mutex mtx_;
GlobalTemporaryMemory(const GlobalTemporaryMemory & o):max_size_(0) {}

public:
GlobalTemporaryMemory():max_size_(0) {}
void * lock() {
// Note: Currently concurrent accesses allocate duplicate memory of
// max_size_ in order to reduce the need to reallocate memory
// This might be a bit wasteful.
boost::lock_guard<boost::mutex> guard(mtx_);
for (int i = 0; i < blocks_.size(); i++) {
void * r = blocks_[i]->try_lock(max_size_);
if (r) return r;
}
blocks_.push_back(boost::make_shared<Block>());
return blocks_.back()->try_lock(max_size_);
}
template<typename Dtype>
Dtype * lock() {
return static_cast<Dtype*>(lock());
}
void unlock(void * mem) {
boost::lock_guard<boost::mutex> guard(mtx_);
for (int i = 0; i < blocks_.size(); i++)
if (blocks_[i]->is_locked_ && blocks_[i]->data_ == mem) {
blocks_[i]->unlock();
return;
}
LOG(WARNING) << "Unlock failed! Lost the memory block!";
}
void allocate(size_t size) {
boost::lock_guard<boost::mutex> guard(mtx_);
if (size > max_size_)
max_size_ = size;
}
};
#ifndef CPU_ONLY
static GlobalTemporaryMemory<true> gpu_memory_;
#endif
static GlobalTemporaryMemory<false> cpu_memory_;

template<typename Dtype>
TemporaryMemory<Dtype>::TemporaryMemory(size_t size):cpu_ptr_(NULL),
gpu_ptr_(NULL), size_(0) {
resize(size);
}
template<typename Dtype>
TemporaryMemory<Dtype>::~TemporaryMemory() {
}

template<typename Dtype>
void TemporaryMemory<Dtype>::acquire_cpu() {
cpu_ptr_ = cpu_memory_.lock<Dtype>();
CHECK(cpu_ptr_ != NULL) << "acquire failed!";
}
template<typename Dtype>
void TemporaryMemory<Dtype>::acquire_gpu() {
#ifdef CPU_ONLY
NO_GPU;
#else
gpu_ptr_ = gpu_memory_.lock<Dtype>();
CHECK(gpu_ptr_ != NULL) << "acquire failed!";
#endif
}
template<typename Dtype>
void TemporaryMemory<Dtype>::release_cpu() {
CHECK(cpu_ptr_ != NULL) << "Need to allocate and acquire the data first";
cpu_memory_.unlock(cpu_ptr_);
cpu_ptr_ = NULL;
}
template<typename Dtype>
void TemporaryMemory<Dtype>::release_gpu() {
#ifdef CPU_ONLY
NO_GPU;
#else
CHECK(gpu_ptr_ != NULL) << "Need to allocate and acquire the data first";
gpu_memory_.unlock(gpu_ptr_);
gpu_ptr_ = NULL;
#endif
}
template<typename Dtype>
const Dtype* TemporaryMemory<Dtype>::cpu_data() const {
CHECK(cpu_ptr_ != NULL) << "Need to allocate and acquire the data first";
return cpu_ptr_;
}
template<typename Dtype>
const Dtype* TemporaryMemory<Dtype>::gpu_data() const {
CHECK(gpu_ptr_ != NULL) << "Need to allocate and acquire the data first";
return gpu_ptr_;
}
template<typename Dtype>
Dtype* TemporaryMemory<Dtype>::mutable_cpu_data() {
CHECK(cpu_ptr_ != NULL) << "Need to allocate and acquire the data first";
return cpu_ptr_;
}
template<typename Dtype>
Dtype* TemporaryMemory<Dtype>::mutable_gpu_data() {
CHECK(gpu_ptr_ != NULL) << "Need to allocate and acquire the data first";
return gpu_ptr_;
}
template<typename Dtype>
void TemporaryMemory<Dtype>::resize(size_t size) {
size_ = size;
#ifndef CPU_ONLY
gpu_memory_.allocate(size_*sizeof(Dtype));
#endif
cpu_memory_.allocate(size_*sizeof(Dtype));
}

INSTANTIATE_CLASS(TemporaryMemory);
template class TemporaryMemory<int>;
template class TemporaryMemory<unsigned int>;

} // namespace caffe
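
Since the pool objects are static to this translation unit, the sharing behavior is easiest to see from the caller's side. A hypothetical illustration (not in the commit) of how blocks are reused and grown:

// Hypothetical callers; blocks are allocated lazily at acquire time.
TemporaryMemory<float> a, b;
a.resize(1 << 20);   // each registers its worst case; the pool keeps the max
b.resize(1 << 20);

a.acquire_cpu();     // pool empty: block 0 allocated (zero-filled) and locked
a.release_cpu();     // block 0 unlocked
b.acquire_cpu();     // block 0 free and large enough: reused, no allocation
a.acquire_cpu();     // block 0 held by b, so a second block is allocated
a.release_cpu();
b.release_cpu();

Peak usage is therefore the largest registered request times the number of overlapping acquisitions, the trade-off the comment in GlobalTemporaryMemory::lock() points out. Keeping acquire/release explicit, rather than tying the block to object lifetime, lets each convolution layer hold the pool only for the span of a single forward or backward call.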
