Merge pull request #2866 from jeffdonahue/fix-weight-sharing
Fix weight sharing
jeffdonahue committed Aug 7, 2015
2 parents fc77ef3 + d5b42bf commit 32ced4f
Showing 5 changed files with 340 additions and 136 deletions.
41 changes: 35 additions & 6 deletions include/caffe/net.hpp
@@ -57,6 +57,12 @@ class Net {
*/
string Forward(const string& input_blob_protos, Dtype* loss = NULL);

/**
* @brief Zeroes out the diffs of all net parameters.
* Should be run before Backward.
*/
void ClearParamDiffs();

/**
* The network backward should take no input and output, since it solely
* computes the gradient w.r.t the parameters, and the data has already been
@@ -84,6 +90,13 @@ class Net {

/// @brief Updates the network weights based on the diff values computed.
void Update();
/**
* @brief Shares weight data of owner blobs with shared blobs.
*
* Note: this is called by Net::Init, and thus should normally not be
* called manually.
*/
void ShareWeights();

/**
* @brief For an already initialized net, implicitly copies (i.e., using no
@@ -148,11 +161,19 @@ class Net {
inline const vector<shared_ptr<Blob<Dtype> > >& params() const {
return params_;
}
/// @brief returns the learnable parameters
inline const vector<Blob<Dtype>*>& learnable_params() const {
return learnable_params_;
}
/// @brief returns the learnable parameter learning rate multipliers
inline const vector<float>& params_lr() const { return params_lr_; }
inline const vector<bool>& has_params_lr() const { return has_params_lr_; }
/// @brief returns the learnable parameter decay multipliers
inline const vector<float>& params_weight_decay() const {
return params_weight_decay_;
}
inline const vector<bool>& has_params_decay() const {
return has_params_decay_;
}
const map<string, int>& param_names_index() const {
return param_names_index_;
}
@@ -213,9 +234,6 @@ class Net {
/// @brief Helper for displaying debug info in Update.
void UpdateDebugInfo(const int param_id);

/// @brief The network name
string name_;
/// @brief The phase: TRAIN or TEST
@@ -254,10 +272,21 @@ class Net {
vector<Blob<Dtype>*> net_output_blobs_;
/// The parameters in the network.
vector<shared_ptr<Blob<Dtype> > > params_;
/// The distinct learnable parameters: each owner blob appears here exactly once.
vector<Blob<Dtype>*> learnable_params_;
/**
* The mapping from params_ -> learnable_params_: we have
* learnable_param_ids_.size() == params_.size(),
* and learnable_params_[learnable_param_ids_[i]] == params_[i].get()
* if and only if params_[i] is an "owner"; otherwise, params_[i] is a sharer
* and learnable_params_[learnable_param_ids_[i]] gives its owner.
*/
vector<int> learnable_param_ids_;
/// the learning rate multipliers for learnable_params_
vector<float> params_lr_;
/// whether an lr_mult was explicitly specified for each learnable parameter
vector<bool> has_params_lr_;
/// the weight decay multipliers for learnable_params_
vector<float> params_weight_decay_;
vector<bool> has_params_decay_;
/// The bytes of memory used by this net
size_t memory_used_;
/// Whether to compute and display debug info for the net.
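
The mapping documented above is the heart of the fix: params() still returns one entry per parameter reference, while the new learnable_params() returns each underlying owner blob exactly once, and params_lr() / params_weight_decay() are now indexed per learnable parameter rather than per blob reference. The following standalone sketch of that bookkeeping is not part of the patch; FakeBlob and the three-parameter layout are invented for illustration.

// Standalone illustration of the params_ / learnable_params_ /
// learnable_param_ids_ bookkeeping described in net.hpp above.
// FakeBlob and the three-parameter layout are made up for this example.
#include <cassert>
#include <memory>
#include <vector>

struct FakeBlob {};  // stand-in for caffe::Blob<Dtype>

int main() {
  std::vector<std::shared_ptr<FakeBlob> > params;  // one entry per parameter reference
  std::vector<FakeBlob*> learnable_params;         // one entry per owner blob
  std::vector<int> learnable_param_ids;            // reference index -> owner slot
  std::vector<int> param_owners;                   // -1 for owners, owner index for sharers

  // Param 0: an owner (say, a weight blob named "shared_weights" seen for the first time).
  params.push_back(std::shared_ptr<FakeBlob>(new FakeBlob));
  param_owners.push_back(-1);
  learnable_param_ids.push_back(static_cast<int>(learnable_params.size()));
  learnable_params.push_back(params.back().get());

  // Param 1: another owner (an anonymous bias blob).
  params.push_back(std::shared_ptr<FakeBlob>(new FakeBlob));
  param_owners.push_back(-1);
  learnable_param_ids.push_back(static_cast<int>(learnable_params.size()));
  learnable_params.push_back(params.back().get());

  // Param 2: a sharer that names "shared_weights" again -> reuse owner 0's slot.
  params.push_back(std::shared_ptr<FakeBlob>(new FakeBlob));
  param_owners.push_back(0);
  learnable_param_ids.push_back(learnable_param_ids[0]);

  // The invariants stated in the header comment above.
  assert(learnable_param_ids.size() == params.size());
  assert(learnable_params[learnable_param_ids[0]] == params[0].get());  // owner maps to itself
  assert(learnable_params[learnable_param_ids[2]] == params[0].get());  // sharer maps to its owner
  assert(learnable_params.size() == 2);  // the shared pair is counted once
  return 0;
}

Because every sharer resolves to its owner's slot, anything done once per learnable parameter happens once per physical blob, no matter how many layers reference it.
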
89 changes: 50 additions & 39 deletions src/caffe/net.cpp
@@ -244,7 +244,7 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
for (size_t layer_id = 0; layer_id < layer_names_.size(); ++layer_id) {
layer_names_index_[layer_names_[layer_id]] = layer_id;
}
ShareWeights();
debug_info_ = param.debug_info();
LOG(INFO) << "Network initialization done.";
LOG(INFO) << "Memory required for data: " << memory_used_ * sizeof(Dtype);
@@ -441,6 +441,9 @@ void Net<Dtype>::AppendParam(const NetParameter& param, const int layer_id,
params_.push_back(layers_[layer_id]->blobs()[param_id]);
param_id_vecs_[layer_id].push_back(net_param_id);
param_layer_indices_.push_back(make_pair(layer_id, param_id));
ParamSpec default_param_spec;
const ParamSpec* param_spec = (layer_param.param_size() > param_id) ?
&layer_param.param(param_id) : &default_param_spec;
if (!param_size || !param_name.size() || (param_name.size() &&
param_names_index_.find(param_name) == param_names_index_.end())) {
// This layer "owns" this parameter blob -- it is either anonymous
@@ -450,6 +453,13 @@ void Net<Dtype>::AppendParam(const NetParameter& param, const int layer_id,
if (param_name.size()) {
param_names_index_[param_name] = net_param_id;
}
const int learnable_param_id = learnable_params_.size();
learnable_params_.push_back(params_[net_param_id].get());
learnable_param_ids_.push_back(learnable_param_id);
has_params_lr_.push_back(param_spec->has_lr_mult());
has_params_decay_.push_back(param_spec->has_decay_mult());
params_lr_.push_back(param_spec->lr_mult());
params_weight_decay_.push_back(param_spec->decay_mult());
} else {
// Named param blob with name we've seen before: share params
const int owner_net_param_id = param_names_index_[param_name];
@@ -474,23 +484,25 @@ void Net<Dtype>::AppendParam(const NetParameter& param, const int layer_id,
// Strict dimension checking -- all dims must be the same.
CHECK(this_blob->shape() == owner_blob->shape());
}
const int learnable_param_id = learnable_param_ids_[owner_net_param_id];
if (param_spec->has_lr_mult()) {
if (has_params_lr_[learnable_param_id]) {
CHECK_EQ(param_spec->lr_mult(), params_lr_[learnable_param_id])
<< "Shared param '" << param_name << "' has mismatched lr_mult.";
} else {
has_params_lr_[learnable_param_id] = true;
params_lr_[learnable_param_id] = param_spec->lr_mult();
}
}
if (param_spec->has_decay_mult()) {
if (has_params_decay_[learnable_param_id]) {
CHECK_EQ(param_spec->decay_mult(),
params_weight_decay_[learnable_param_id])
<< "Shared param '" << param_name << "' has mismatched decay_mult.";
} else {
has_params_decay_[learnable_param_id] = true;
params_weight_decay_[learnable_param_id] = param_spec->decay_mult();
}
}
}
}
@@ -895,39 +907,38 @@ void Net<Dtype>::ToHDF5(const string& filename, bool write_diff) const {

template <typename Dtype>
void Net<Dtype>::Update() {
for (int i = 0; i < learnable_params_.size(); ++i) {
learnable_params_[i]->Update();
}
}

template <typename Dtype>
void Net<Dtype>::ClearParamDiffs() {
for (int i = 0; i < learnable_params_.size(); ++i) {
Blob<Dtype>* blob = learnable_params_[i];
switch (Caffe::mode()) {
case Caffe::CPU:
caffe_set(blob->count(), static_cast<Dtype>(0),
blob->mutable_cpu_diff());
break;
case Caffe::GPU:
#ifndef CPU_ONLY
caffe_gpu_set(blob->count(), static_cast<Dtype>(0),
blob->mutable_gpu_diff());
#else
NO_GPU;
#endif
break;
default:
LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
}
}
}

template <typename Dtype>
void Net<Dtype>::ShareWeights() {
for (int i = 0; i < params_.size(); ++i) {
if (param_owners_[i] < 0) { continue; }
params_[i]->ShareData(*params_[param_owners_[i]]);
params_[i]->ShareDiff(*params_[param_owners_[i]]);
}
}

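
These three functions replace the scheme visible in the deleted code: previously Net::Update accumulated each sharer's diff into its owner's diff (the caffe_add / caffe_gpu_add loop) before applying it, and a separate GetLearningRateAndWeightDecay pass pushed one lr_mult / decay_mult entry per blob reference. Now ShareWeights(), called from Net::Init, aliases both the data and the diff of every sharer to its owner, so gradients land in a single buffer during Backward, and Update() / ClearParamDiffs() simply walk learnable_params_. Below is a minimal mock of that aliasing idea, not Caffe code; MiniBlob and the two-blob setup stand in for Blob<Dtype> and a real net.

// Minimal mock of the sharing mechanism above (not Caffe code: MiniBlob and the
// two-blob setup are invented). ShareDiff-style aliasing means a gradient written
// through either reference lands in one buffer, so a single Update() on the owner
// applies the combined gradient exactly once.
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

struct MiniBlob {
  std::shared_ptr<std::vector<float> > data, diff;
  explicit MiniBlob(int n)
      : data(new std::vector<float>(n, 1.0f)), diff(new std::vector<float>(n, 0.0f)) {}
  void ShareData(const MiniBlob& other) { data = other.data; }  // alias, like Blob::ShareData
  void ShareDiff(const MiniBlob& other) { diff = other.diff; }  // alias, like Blob::ShareDiff
  void Update() {  // data -= diff, like Blob::Update
    for (std::size_t i = 0; i < data->size(); ++i) (*data)[i] -= (*diff)[i];
  }
};

int main() {
  MiniBlob owner(1), sharer(1);
  sharer.ShareData(owner);     // what Net::ShareWeights now does at Init time
  sharer.ShareDiff(owner);

  (*owner.diff)[0] += 0.25f;   // gradient from the layer that owns the blob
  (*sharer.diff)[0] += 0.50f;  // gradient from the layer that shares it

  owner.Update();              // applied once to the single underlying buffer
  assert((*owner.data)[0] == 1.0f - 0.75f);
  assert(owner.data == sharer.data);  // still the same storage
  return 0;
}

ClearParamDiffs() is needed because diffs are accumulated rather than overwritten, so they must be zeroed before each Backward pass; the solver.cpp change below moves that zeroing behind this single call.
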
44 changes: 12 additions & 32 deletions src/caffe/solver.cpp
@@ -173,24 +173,7 @@ void Solver<Dtype>::Step(int iters) {

while (iter_ < stop_iter) {
// zero-init the params
net_->ClearParamDiffs();
if (param_.test_interval() && iter_ % param_.test_interval() == 0
&& (iter_ > 0 || param_.test_initialization())) {
TestAll();
@@ -462,7 +445,7 @@ Dtype SGDSolver<Dtype>::GetLearningRate() {
template <typename Dtype>
void SGDSolver<Dtype>::PreSolve() {
// Initialize the history
const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
history_.clear();
update_.clear();
temp_.clear();
@@ -478,12 +461,10 @@ template <typename Dtype>
void SGDSolver<Dtype>::ClipGradients() {
const Dtype clip_gradients = this->param_.clip_gradients();
if (clip_gradients < 0) { return; }
const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
Dtype sumsq_diff = 0;
for (int i = 0; i < net_params.size(); ++i) {
sumsq_diff += net_params[i]->sumsq_diff();
}
const Dtype l2norm_diff = std::sqrt(sumsq_diff);
if (l2norm_diff > clip_gradients) {
@@ -492,9 +473,7 @@ void SGDSolver<Dtype>::ClipGradients() {
<< l2norm_diff << " > " << clip_gradients << ") "
<< "by scale factor " << scale_factor;
for (int i = 0; i < net_params.size(); ++i) {
net_params[i]->scale_diff(scale_factor);
}
}
}
@@ -506,7 +485,8 @@ void SGDSolver<Dtype>::ApplyUpdate() {
LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
}
ClipGradients();
for (int param_id = 0; param_id < this->net_->learnable_params().size();
++param_id) {
Normalize(param_id);
Regularize(param_id);
ComputeUpdateValue(param_id, rate);
@@ -518,7 +498,7 @@ template <typename Dtype>
void SGDSolver<Dtype>::Normalize(int param_id) {
if (this->param_.iter_size() == 1) { return; }
// Scale gradient to counterbalance accumulation.
const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
const Dtype accum_normalization = Dtype(1.) / this->param_.iter_size();
switch (Caffe::mode()) {
case Caffe::CPU: {
@@ -542,7 +522,7 @@

template <typename Dtype>
void SGDSolver<Dtype>::Regularize(int param_id) {
const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
const vector<float>& net_params_weight_decay =
this->net_->params_weight_decay();
Dtype weight_decay = this->param_.weight_decay();
@@ -604,7 +584,7 @@

template <typename Dtype>
void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
const vector<float>& net_params_lr = this->net_->params_lr();
Dtype momentum = this->param_.momentum();
Dtype local_rate = rate * net_params_lr[param_id];
@@ -743,7 +723,7 @@ void SGDSolver<Dtype>::RestoreSolverStateFromHDF5(const string& state_file) {

template <typename Dtype>
void NesterovSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
const vector<float>& net_params_lr = this->net_->params_lr();
Dtype momentum = this->param_.momentum();
Dtype local_rate = rate * net_params_lr[param_id];
@@ -803,7 +783,7 @@ void NesterovSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {

template <typename Dtype>
void AdaGradSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
const vector<float>& net_params_lr = this->net_->params_lr();
Dtype delta = this->param_.delta();
Dtype local_rate = rate * net_params_lr[param_id];
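
Throughout solver.cpp, loops that walked net_->params() (one entry per parameter reference) now walk net_->learnable_params() (one entry per owner blob). That is why the explicit param_owners() guards in ClipGradients could be dropped: the list itself contains each physical parameter exactly once, and the lr_mult / decay_mult CHECKs added in Net::AppendParam guarantee a single well-defined multiplier per shared blob. The toy program below, which is not Caffe code (the values and the two-reference layout are invented), illustrates the per-reference double counting that iterating owners avoids.

// Toy illustration (not Caffe code) of iterating unique owner blobs instead of
// per-reference entries: with weight sharing, a per-reference pass would touch
// the same underlying weight once for every layer that uses it.
#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  const float weight_decay = 0.25f;
  float shared_weight = 1.0f;

  // Hypothetical net: two layers reference the same weight blob.
  std::vector<float*> params = {&shared_weight, &shared_weight};  // per reference
  std::vector<float*> learnable_params = {&shared_weight};        // per owner

  // Per-reference pass: the decay term is counted twice for the shared blob.
  float per_reference = 0.0f;
  for (std::size_t i = 0; i < params.size(); ++i) {
    per_reference += weight_decay * (*params[i]);
  }
  assert(per_reference == 0.5f);

  // Per-owner pass (what the solver now does): counted exactly once.
  float per_owner = 0.0f;
  for (std::size_t i = 0; i < learnable_params.size(); ++i) {
    per_owner += weight_decay * (*learnable_params[i]);
  }
  assert(per_owner == 0.25f);
  return 0;
}

The same reasoning applies to Normalize, Regularize, and ComputeUpdateValue: each now sees a shared parameter once, with its single accumulated gradient.
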