Modifications to Net to facilitate unrolled recurrent networks
jeffdonahue committed Feb 16, 2015
1 parent 25b87e9 commit 8983e6f
Showing 2 changed files with 21 additions and 31 deletions.
10 changes: 10 additions & 0 deletions include/caffe/net.hpp
@@ -84,6 +84,13 @@ class Net {
 
   /// @brief Updates the network weights based on the diff values computed.
   void Update();
+  /**
+   * @brief Shares weight data of owner blobs with shared blobs.
+   *
+   * Note: this is called by Net::Init, and thus should normally not be
+   * called manually.
+   */
+  void ShareWeightData();
 
   /**
    * @brief For an already initialized net, implicitly copies (i.e., using no
@@ -148,6 +155,9 @@ class Net {
     return param_names_index_;
   }
   inline const vector<int>& param_owners() const { return param_owners_; }
+  inline const vector<string>& param_display_names() const {
+    return param_display_names_;
+  }
   /// @brief Input and output blob numbers
   inline int num_inputs() const { return net_input_blobs_.size(); }
   inline int num_outputs() const { return net_output_blobs_.size(); }
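
The ShareWeightData() declared here is what ties parameters together for an unrolled recurrent net: after Net::Init has appended every parameter, each blob whose param_owners_ entry is non-negative is made to alias its owner blob's storage (see the net.cpp changes below). A minimal standalone sketch of that ownership scheme — not Caffe's Blob/Net code; params and param_owners are hypothetical stand-ins for Net::params_ and Net::param_owners_, and whole buffers are swapped where the real code re-points a blob's data:

// Standalone sketch of ownership-based weight sharing. NOT Caffe's Blob API:
// "params" and "param_owners" mimic Net::params_ and Net::param_owners_.
#include <cassert>
#include <iostream>
#include <memory>
#include <vector>

int main() {
  using Buffer = std::shared_ptr<std::vector<float>>;
  // Three parameters; #1 and #2 are unrolled copies that should share #0.
  std::vector<Buffer> params = {
      std::make_shared<std::vector<float>>(4, 0.5f),
      std::make_shared<std::vector<float>>(4, 0.0f),
      std::make_shared<std::vector<float>>(4, 0.0f)};
  std::vector<int> param_owners = {-1, 0, 0};  // -1: this blob owns its data

  // Analogue of ShareWeightData(): one pass, run after every parameter has
  // been appended, makes each shared parameter alias its owner's storage.
  for (size_t i = 0; i < params.size(); ++i) {
    if (param_owners[i] < 0) { continue; }
    params[i] = params[param_owners[i]];
  }

  // A write through the owner is now visible through every shared copy.
  (*params[0])[0] = 42.0f;
  assert((*params[2])[0] == 42.0f);
  std::cout << "shared copy sees " << (*params[2])[0] << "\n";
  return 0;
}
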
42 changes: 11 additions & 31 deletions src/caffe/net.cpp
@@ -206,6 +206,7 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
     layer_names_index_[layer_names_[layer_id]] = layer_id;
   }
   GetLearningRateAndWeightDecay();
+  ShareWeightData();
   debug_info_ = param.debug_info();
   LOG(INFO) << "Network initialization done.";
   LOG(INFO) << "Memory required for data: " << memory_used_ * sizeof(Dtype);
@@ -444,8 +445,6 @@ void Net<Dtype>::AppendParam(const NetParameter& param, const int layer_id,
       // Strict dimension checking -- all dims must be the same.
       CHECK(this_blob->shape() == owner_blob->shape());
     }
-    layers_[layer_id]->blobs()[param_id]->ShareData(
-        *layers_[owner_layer_id]->blobs()[owner_param_id]);
   }
 }
 
@@ -749,42 +748,23 @@ void Net<Dtype>::ToProto(NetParameter* param, bool write_diff) const {
 
 template <typename Dtype>
 void Net<Dtype>::Update() {
-  // First, accumulate the diffs of any shared parameters into their owner's
-  // diff. (Assumes that the learning rate, weight decay, etc. have already been
-  // accounted for in the current diff.)
-  for (int i = 0; i < params_.size(); ++i) {
-    if (param_owners_[i] < 0) { continue; }
-    if (debug_info_) { UpdateDebugInfo(i); }
-    const int count = params_[i]->count();
-    const Dtype* this_diff;
-    Dtype* owner_diff;
-    switch (Caffe::mode()) {
-    case Caffe::CPU:
-      this_diff = params_[i]->cpu_diff();
-      owner_diff = params_[param_owners_[i]]->mutable_cpu_diff();
-      caffe_add(count, this_diff, owner_diff, owner_diff);
-      break;
-#ifndef CPU_ONLY
-    case Caffe::GPU:
-      this_diff = params_[i]->gpu_diff();
-      owner_diff = params_[param_owners_[i]]->mutable_gpu_diff();
-      caffe_gpu_add(count, this_diff, owner_diff, owner_diff);
-      break;
-#else
-      NO_GPU;
-#endif
-    default:
-      LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
-    }
-  }
-  // Now, update the owned parameters.
+  // Update only the owned parameters.
   for (int i = 0; i < params_.size(); ++i) {
     if (param_owners_[i] >= 0) { continue; }
     if (debug_info_) { UpdateDebugInfo(i); }
     params_[i]->Update();
   }
 }
 
 template <typename Dtype>
+void Net<Dtype>::ShareWeightData() {
+  for (int i = 0; i < params_.size(); ++i) {
+    if (param_owners_[i] < 0) { continue; }
+    params_[i]->ShareData(*params_[param_owners_[i]]);
+    params_[i]->ShareDiff(*params_[param_owners_[i]]);
+  }
+}
+
+template <typename Dtype>
 bool Net<Dtype>::has_blob(const string& blob_name) const {
   return blob_names_index_.find(blob_name) != blob_names_index_.end();
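
The block deleted from Update() used to sum each shared parameter's diff into its owner's diff before the owner was updated. With ShareWeightData() now sharing the diff as well as the data (ShareDiff), every unrolled copy of a tied weight writes its gradient into the same buffer during backward, so the owner's diff already holds the total — provided the backward pass accumulates (adds into) parameter diffs rather than overwriting them, which is the behavior this change relies on. A toy sketch of that reasoning, again not Caffe code (owner_diff and grads are made-up names):

// Toy illustration: once tied parameters alias a single diff buffer, gradient
// contributions from every unrolled copy sum in place, and the explicit
// accumulation loop removed from Net::Update() is unnecessary.
#include <cassert>
#include <iostream>
#include <memory>
#include <vector>

int main() {
  auto owner_diff = std::make_shared<std::vector<float>>(1, 0.0f);
  // Two "unrolled" copies of one weight share the owner's diff buffer,
  // which is what ShareDiff() arranges in the real code.
  std::vector<std::shared_ptr<std::vector<float>>> diffs = {owner_diff,
                                                            owner_diff};

  // Backward pass of each copy accumulates its gradient contribution.
  const float grads[] = {0.25f, 0.75f};
  for (size_t t = 0; t < diffs.size(); ++t) {
    (*diffs[t])[0] += grads[t];
  }

  // The owner's diff already holds the total; Update() only needs to apply
  // it to the owned parameter.
  assert((*owner_diff)[0] == 1.0f);
  std::cout << "accumulated gradient = " << (*owner_diff)[0] << "\n";
  return 0;
}
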
