Skip to content

Commit

Permalink
Add option not to reshape to Blob::FromProto; use when loading Blobs
Browse files Browse the repository at this point in the history
from saved NetParameter

Want to keep the param Blob shape the layer has set, and not necessarily
adopt the one from the saved net (e.g. want to keep new 1D bias shape,
rather than take the (1 x 1 x 1 x D) shape from a legacy net).
  • Loading branch information
jeffdonahue committed Feb 25, 2015
1 parent a359a43 commit a8023e2
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
2 changes: 1 addition & 1 deletion include/caffe/blob.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ class Blob {
Dtype* mutable_cpu_diff();
Dtype* mutable_gpu_diff();
void Update();
void FromProto(const BlobProto& proto);
void FromProto(const BlobProto& proto, bool reshape = true);
void ToProto(BlobProto* proto, bool write_diff = false) const;

/// @brief Compute the sum of absolute values (L1 norm) of the data.
Expand Down
6 changes: 4 additions & 2 deletions src/caffe/blob.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ void Blob<Dtype>::CopyFrom(const Blob& source, bool copy_diff, bool reshape) {
}

template <typename Dtype>
void Blob<Dtype>::FromProto(const BlobProto& proto) {
void Blob<Dtype>::FromProto(const BlobProto& proto, bool reshape) {
vector<int> shape;
if (proto.has_num() || proto.has_channels() ||
proto.has_height() || proto.has_width()) {
Expand All @@ -450,7 +450,9 @@ void Blob<Dtype>::FromProto(const BlobProto& proto) {
shape[i] = proto.shape().dim(i);
}
}
Reshape(shape);
if (reshape) {
Reshape(shape);
}
// copy data
Dtype* data_vec = mutable_cpu_data();
for (int i = 0; i < count_; ++i) {
Expand Down
3 changes: 2 additions & 1 deletion src/caffe/net.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,8 @@ void Net<Dtype>::CopyTrainedLayersFrom(const NetParameter& param) {
<< "Incompatible number of blobs for layer " << source_layer_name;
for (int j = 0; j < target_blobs.size(); ++j) {
CHECK(target_blobs[j]->ShapeEquals(source_layer.blobs(j)));
target_blobs[j]->FromProto(source_layer.blobs(j));
const bool kReshape = false;
target_blobs[j]->FromProto(source_layer.blobs(j), kReshape);
}
}
}
Expand Down

0 comments on commit a8023e2

Please sign in to comment.