Blobs are ND arrays (for N not necessarily equal to 4).
Blobs now store their dimensions in a vector<int> shape_ instead of the fixed (num, channels, height, width) quadruple; the 4D constructors and accessors remain as deprecated wrappers.
jeffdonahue committed Mar 3, 2015
1 parent 4fba3da commit 1434e87
Showing 6 changed files with 213 additions and 76 deletions.
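
For a concrete picture of the new interface, here is a minimal usage sketch against the headers introduced in this commit; the shapes are arbitrary examples and the snippet is not part of the change itself.

#include <vector>

#include "caffe/blob.hpp"

using caffe::Blob;

int main() {
  // Legacy 4D constructor: still available, now marked deprecated.
  Blob<float> legacy(2, 3, 4, 5);       // shape (2, 3, 4, 5), count() == 120

  // New ND constructor: any number of axes via vector<int>.
  std::vector<int> shape(5);
  shape[0] = 2; shape[1] = 3; shape[2] = 4; shape[3] = 5; shape[4] = 6;
  Blob<float> nd(shape);                // count() == 720

  // Reshape also takes a vector<int>, so a blob can change its number of axes.
  nd.Reshape(std::vector<int>(2, 12));  // now a 12 x 12 blob
  return 0;
}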
146 changes: 125 additions & 21 deletions include/caffe/blob.hpp
@@ -1,11 +1,17 @@
#ifndef CAFFE_BLOB_HPP_
#define CAFFE_BLOB_HPP_

#include <algorithm>
#include <string>
#include <vector>

#include "caffe/common.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/syncedmem.hpp"
#include "caffe/util/math_functions.hpp"

const int kMaxBlobAxes = INT_MAX;

namespace caffe {

/**
@@ -19,10 +25,16 @@ template <typename Dtype>
class Blob {
public:
Blob()
: data_(), diff_(), num_(0), channels_(0), height_(0), width_(0),
count_(0), capacity_(0) {}
: data_(), diff_(), count_(0), capacity_(0) {}

/// @brief Deprecated; use <code>Blob(const vector<int>& shape)</code>.
explicit Blob(const int num, const int channels, const int height,
const int width);
const int width);
explicit Blob(const vector<int>& shape);

/// @brief Deprecated; use <code>Reshape(const vector<int>& shape)</code>.
void Reshape(const int num, const int channels, const int height,
const int width);
/**
* @brief Change the dimensions of the blob, allocating new memory if
* necessary.
@@ -37,25 +49,118 @@ class Blob {
* an error; either Net::Forward or Net::Reshape need to be called to
* propagate the new input shape to higher layers.
*/
void Reshape(const int num, const int channels, const int height,
const int width);
void Reshape(const vector<int>& shape);
void ReshapeLike(const Blob& other);
inline int num() const { return num_; }
inline int channels() const { return channels_; }
inline int height() const { return height_; }
inline int width() const { return width_; }
inline string shape_string() const {
ostringstream stream;
for (int i = 0; i < shape_.size(); ++i) {
stream << shape_[i] << " ";
}
stream << "(" << count_ << ")";
return stream.str();
}
inline const vector<int>& shape() const { return shape_; }
/**
* @brief Returns the dimension of the index-th axis (or the negative index-th
* axis from the end, if index is negative).
*
* @param index the axis index, which may be negative as it will be
* "canonicalized" using CanonicalAxisIndex.
* Dies on out of range index.
*/
inline int shape(int index) const {
return shape_[CanonicalAxisIndex(index)];
}
inline int num_axes() const { return shape_.size(); }
inline int count() const { return count_; }

/**
* @brief Compute the volume of a slice; i.e., the product of dimensions
* among a range of axes.
*
* @param start_axis The first axis to include in the slice.
*
* @param end_axis The first axis to exclude from the slice.
*/
inline int count(int start_axis, int end_axis) const {
CHECK_LE(start_axis, end_axis);
CHECK_GE(start_axis, 0);
CHECK_GE(end_axis, 0);
CHECK_LE(start_axis, num_axes());
CHECK_LE(end_axis, num_axes());
int count = 1;
for (int i = start_axis; i < end_axis; ++i) {
count *= shape(i);
}
return count;
}
/**
* @brief Compute the volume of a slice spanning from a particular first
* axis to the final axis.
*
* @param start_axis The first axis to include in the slice.
*/
inline int count(int start_axis) const {
return count(start_axis, num_axes());
}

/**
* @brief Returns the 'canonical' version of a (usually) user-specified axis,
* allowing for negative indexing (e.g., -1 for the last axis).
*
* @param index the axis index.
* If 0 <= index < num_axes(), return index.
* If -num_axes <= index <= -1, return (num_axes() - (-index)),
* e.g., the last axis index (num_axes() - 1) if index == -1,
* the second to last if index == -2, etc.
* Dies on out of range index.
*/
inline int CanonicalAxisIndex(int axis_index) const {
CHECK_GE(axis_index, -num_axes())
<< "axis " << axis_index << " out of range for " << num_axes()
<< "-D Blob with shape " << shape_string();
CHECK_LT(axis_index, num_axes())
<< "axis " << axis_index << " out of range for " << num_axes()
<< "-D Blob with shape " << shape_string();
if (axis_index < 0) {
return axis_index + num_axes();
}
return axis_index;
}

/// @brief Deprecated legacy shape accessor num: use shape(0) instead.
inline int num() const { return LegacyShape(0); }
/// @brief Deprecated legacy shape accessor channels: use shape(1) instead.
inline int channels() const { return LegacyShape(1); }
/// @brief Deprecated legacy shape accessor height: use shape(2) instead.
inline int height() const { return LegacyShape(2); }
/// @brief Deprecated legacy shape accessor width: use shape(3) instead.
inline int width() const { return LegacyShape(3); }
inline int LegacyShape(int index) const {
CHECK_LE(num_axes(), 4)
<< "Cannot use legacy accessors on Blobs with > 4 axes.";
CHECK_LT(index, 4);
CHECK_GE(index, -4);
if (index >= num_axes() || index < -num_axes()) {
// Axis is out of range, but still in [0, 3] (or [-4, -1] for reverse
// indexing) -- this special case simulates the one-padding used to fill
// extraneous axes of legacy blobs.
return 1;
}
return shape(index);
}

inline int offset(const int n, const int c = 0, const int h = 0,
const int w = 0) const {
CHECK_GE(n, 0);
CHECK_LE(n, num_);
CHECK_GE(channels_, 0);
CHECK_LE(c, channels_);
CHECK_GE(height_, 0);
CHECK_LE(h, height_);
CHECK_GE(width_, 0);
CHECK_LE(w, width_);
return ((n * channels_ + c) * height_ + h) * width_ + w;
CHECK_LE(n, num());
CHECK_GE(channels(), 0);
CHECK_LE(c, channels());
CHECK_GE(height(), 0);
CHECK_LE(h, height());
CHECK_GE(width(), 0);
CHECK_LE(w, width());
return ((n * channels() + c) * height() + h) * width() + w;
}
/**
* @brief Copy from a source Blob.
@@ -135,13 +240,12 @@
*/
void ShareDiff(const Blob& other);

bool ShapeEquals(const BlobProto& other);

protected:
shared_ptr<SyncedMemory> data_;
shared_ptr<SyncedMemory> diff_;
int num_;
int channels_;
int height_;
int width_;
vector<int> shape_;
int count_;
int capacity_;

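To illustrate the axis helpers added above (negative indexing via CanonicalAxisIndex, partial volumes via count(start, end), and the one-padded legacy accessors), a short hypothetical example; the 10 x 5 shape is made up.

#include <iostream>
#include <vector>

#include "caffe/blob.hpp"

using caffe::Blob;

int main() {
  // A 2-axis blob, e.g. an inner-product weight of shape M x N = 10 x 5.
  std::vector<int> shape(2);
  shape[0] = 10;
  shape[1] = 5;
  Blob<float> weights(shape);

  // Negative indices are canonicalized: -1 is the last axis, -2 the next, etc.
  int cols = weights.shape(-1);         // 5
  int rows = weights.shape(-2);         // 10

  // count(start, end) is the volume over the half-open axis range [start, end);
  // count(start) runs through the final axis.
  int all  = weights.count(0, 2);       // 10 * 5 = 50
  int tail = weights.count(1);          // 5

  // Legacy accessors one-pad the missing trailing axes on a 2-axis blob:
  // num() == 10, channels() == 5, height() == 1, width() == 1.
  std::cout << weights.shape_string()   // prints "10 5 (50)"
            << " rows=" << rows << " cols=" << cols
            << " all=" << all << " tail=" << tail << std::endl;
  return 0;
}
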
92 changes: 74 additions & 18 deletions src/caffe/blob.cpp
@@ -1,3 +1,6 @@
#include <climits>
#include <vector>

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/syncedmem.hpp"
@@ -8,15 +11,24 @@ namespace caffe {
template <typename Dtype>
void Blob<Dtype>::Reshape(const int num, const int channels, const int height,
const int width) {
CHECK_GE(num, 0);
CHECK_GE(channels, 0);
CHECK_GE(height, 0);
CHECK_GE(width, 0);
num_ = num;
channels_ = channels;
height_ = height;
width_ = width;
count_ = num_ * channels_ * height_ * width_;
vector<int> shape(4);
shape[0] = num;
shape[1] = channels;
shape[2] = height;
shape[3] = width;
Reshape(shape);
}

template <typename Dtype>
void Blob<Dtype>::Reshape(const vector<int>& shape) {
CHECK_LE(shape.size(), kMaxBlobAxes);
count_ = 1;
shape_.resize(shape.size());
for (int i = 0; i < shape.size(); ++i) {
CHECK_GE(shape[i], 0);
count_ *= shape[i];
shape_[i] = shape[i];
}
if (count_ > capacity_) {
capacity_ = count_;
data_.reset(new SyncedMemory(capacity_ * sizeof(Dtype)));
@@ -26,7 +38,7 @@ void Blob<Dtype>::Reshape(const int num, const int channels, const int height,

template <typename Dtype>
void Blob<Dtype>::ReshapeLike(const Blob<Dtype>& other) {
Reshape(other.num(), other.channels(), other.height(), other.width());
Reshape(other.shape());
}

template <typename Dtype>
@@ -37,6 +49,13 @@ Blob<Dtype>::Blob(const int num, const int channels, const int height,
Reshape(num, channels, height, width);
}

template <typename Dtype>
Blob<Dtype>::Blob(const vector<int>& shape)
// capacity_ must be initialized before calling Reshape
: capacity_(0) {
Reshape(shape);
}

template <typename Dtype>
const Dtype* Blob<Dtype>::cpu_data() const {
CHECK(data_);
@@ -345,12 +364,34 @@ void Blob<Dtype>::scale_diff(Dtype scale_factor) {
}
}

template <typename Dtype>
bool Blob<Dtype>::ShapeEquals(const BlobProto& other) {
if (other.has_num() || other.has_channels() ||
other.has_height() || other.has_width()) {
// Using deprecated 4D Blob dimensions --
// shape is (num, channels, height, width).
// Note: we do not use the normal Blob::num(), Blob::channels(), etc.
// methods as these index from the beginning of the blob shape, where legacy
// parameter blobs were indexed from the end of the blob shape (e.g., bias
// Blob shape (1 x 1 x 1 x N), IP layer weight Blob shape (1 x 1 x M x N)).
return shape_.size() <= 4 &&
LegacyShape(-4) == other.num() &&
LegacyShape(-3) == other.channels() &&
LegacyShape(-2) == other.height() &&
LegacyShape(-1) == other.width();
}
vector<int> other_shape(other.dim_size());
for (int i = 0; i < other.dim_size(); ++i) {
other_shape[i] = other.dim(i);
}
return shape_ == other_shape;
}

template <typename Dtype>
void Blob<Dtype>::CopyFrom(const Blob& source, bool copy_diff, bool reshape) {
if (num_ != source.num() || channels_ != source.channels() ||
height_ != source.height() || width_ != source.width()) {
if (source.count() != count_ || source.shape() != shape_) {
if (reshape) {
Reshape(source.num(), source.channels(), source.height(), source.width());
ReshapeLike(source);
} else {
LOG(FATAL) << "Trying to copy blobs of different sizes.";
}
@@ -381,7 +422,23 @@ void Blob<Dtype>::CopyFrom(const Blob& source, bool copy_diff, bool reshape) {

template <typename Dtype>
void Blob<Dtype>::FromProto(const BlobProto& proto) {
Reshape(proto.num(), proto.channels(), proto.height(), proto.width());
vector<int> shape;
if (proto.has_num() || proto.has_channels() ||
proto.has_height() || proto.has_width()) {
// Using deprecated 4D Blob dimensions --
// shape is (num, channels, height, width).
shape.resize(4);
shape[0] = proto.num();
shape[1] = proto.channels();
shape[2] = proto.height();
shape[3] = proto.width();
} else {
shape.resize(proto.dim_size());
for (int i = 0; i < proto.dim_size(); ++i) {
shape[i] = proto.dim(i);
}
}
Reshape(shape);
// copy data
Dtype* data_vec = mutable_cpu_data();
for (int i = 0; i < count_; ++i) {
@@ -397,10 +454,9 @@

template <typename Dtype>
void Blob<Dtype>::ToProto(BlobProto* proto, bool write_diff) const {
proto->set_num(num_);
proto->set_channels(channels_);
proto->set_height(height_);
proto->set_width(width_);
for (int i = 0; i < shape_.size(); ++i) {
proto->add_dim(shape_[i]);
}
proto->clear_data();
proto->clear_diff();
const Dtype* data_vec = cpu_data();
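A hypothetical serialization round-trip through the updated ToProto, FromProto, and ShapeEquals shown above; shapes and values are chosen only for illustration.

#include <vector>

#include "caffe/blob.hpp"
#include "caffe/proto/caffe.pb.h"

using caffe::Blob;
using caffe::BlobProto;

int main() {
  // ToProto now writes the repeated "dim" field rather than the fixed
  // (num, channels, height, width) quadruple.
  std::vector<int> shape(3);
  shape[0] = 2; shape[1] = 3; shape[2] = 4;
  Blob<float> blob(shape);
  BlobProto proto;
  blob.ToProto(&proto, false);          // proto.dim_size() == 3

  // FromProto reshapes from "dim" (or from the legacy fields if present).
  Blob<float> copy;
  copy.FromProto(proto);                // copy.shape() == {2, 3, 4}

  // ShapeEquals aligns legacy 4D protos with the *end* of the new shape:
  // a (1, 1, 1, 4) proto matches a 1-axis blob of length 4.
  BlobProto legacy;
  legacy.set_num(1);
  legacy.set_channels(1);
  legacy.set_height(1);
  legacy.set_width(4);
  Blob<float> bias(std::vector<int>(1, 4));
  return bias.ShapeEquals(legacy) ? 0 : 1;  // expected: 0 (shapes match)
}
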
25 changes: 4 additions & 21 deletions src/caffe/net.cpp
@@ -109,11 +109,7 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
blob_loss_weights_.resize(top_id_vecs_[layer_id][top_id] + 1, Dtype(0));
}
blob_loss_weights_[top_id_vecs_[layer_id][top_id]] = layer->loss(top_id);
LOG(INFO) << "Top shape: " << top_vecs_[layer_id][top_id]->num() << " "
<< top_vecs_[layer_id][top_id]->channels() << " "
<< top_vecs_[layer_id][top_id]->height() << " "
<< top_vecs_[layer_id][top_id]->width() << " ("
<< top_vecs_[layer_id][top_id]->count() << ")";
LOG(INFO) << "Top shape: " << top_vecs_[layer_id][top_id]->shape_string();
if (layer->loss(top_id)) {
LOG(INFO) << " with loss weight " << layer->loss(top_id);
}
@@ -427,14 +423,7 @@ void Net<Dtype>::AppendParam(const NetParameter& param, const int layer_id,
<< "Shared parameter blobs must have the same count.";
} else {
// Strict dimension checking -- all dims must be the same.
CHECK_EQ(this_blob->num(), owner_blob->num())
<< "Shared parameter blobs must have the same num.";
CHECK_EQ(this_blob->channels(), owner_blob->channels())
<< "Shared parameter blobs must have the same channels.";
CHECK_EQ(this_blob->height(), owner_blob->height())
<< "Shared parameter blobs must have the same height.";
CHECK_EQ(this_blob->width(), owner_blob->width())
<< "Shared parameter blobs must have the same width.";
CHECK(this_blob->shape() == owner_blob->shape());
}
layers_[layer_id]->blobs()[param_id]->ShareData(
*layers_[owner_layer_id]->blobs()[owner_param_id]);
@@ -640,10 +629,7 @@ void Net<Dtype>::ShareTrainedLayersWith(const Net* other) {
<< "Incompatible number of blobs for layer " << source_layer_name;
for (int j = 0; j < target_blobs.size(); ++j) {
Blob<Dtype>* source_blob = source_layer->blobs()[j].get();
CHECK_EQ(target_blobs[j]->num(), source_blob->num());
CHECK_EQ(target_blobs[j]->channels(), source_blob->channels());
CHECK_EQ(target_blobs[j]->height(), source_blob->height());
CHECK_EQ(target_blobs[j]->width(), source_blob->width());
CHECK(target_blobs[j]->shape() == source_blob->shape());
target_blobs[j]->ShareData(*source_blob);
}
}
@@ -707,10 +693,7 @@ void Net<Dtype>::CopyTrainedLayersFrom(const NetParameter& param) {
CHECK_EQ(target_blobs.size(), source_layer.blobs_size())
<< "Incompatible number of blobs for layer " << source_layer_name;
for (int j = 0; j < target_blobs.size(); ++j) {
CHECK_EQ(target_blobs[j]->num(), source_layer.blobs(j).num());
CHECK_EQ(target_blobs[j]->channels(), source_layer.blobs(j).channels());
CHECK_EQ(target_blobs[j]->height(), source_layer.blobs(j).height());
CHECK_EQ(target_blobs[j]->width(), source_layer.blobs(j).width());
CHECK(target_blobs[j]->ShapeEquals(source_layer.blobs(j)));
target_blobs[j]->FromProto(source_layer.blobs(j));
}
}
7 changes: 5 additions & 2 deletions src/caffe/proto/caffe.proto
@@ -3,12 +3,15 @@ syntax = "proto2";
package caffe;

message BlobProto {
repeated int32 dim = 7 [packed = true];
repeated float data = 5 [packed = true];
repeated float diff = 6 [packed = true];

// 4D dimensions -- deprecated. Use "dim" instead.
optional int32 num = 1 [default = 0];
optional int32 channels = 2 [default = 0];
optional int32 height = 3 [default = 0];
optional int32 width = 4 [default = 0];
repeated float data = 5 [packed = true];
repeated float diff = 6 [packed = true];
}

// The BlobProtoVector is simply a way to pass multiple blobproto instances
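As a sketch of the protocol change, a BlobProto carrying the new repeated dim field can be built from protobuf text format; the values below are illustrative only.

#include <string>

#include <google/protobuf/text_format.h>

#include "caffe/proto/caffe.pb.h"

int main() {
  // New-style BlobProto: dimensions in the repeated "dim" field instead of
  // the deprecated (num, channels, height, width) quadruple.
  const std::string text = "dim: 10 dim: 5";
  caffe::BlobProto proto;
  bool parsed = google::protobuf::TextFormat::ParseFromString(text, &proto);
  // parsed == true; proto.dim_size() == 2, proto.dim(0) == 10, proto.dim(1) == 5.
  return parsed ? 0 : 1;
}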
