Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Snapshot model weights/solver state to HDF5 files #2836

Merged
merged 5 commits into from
Aug 7, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ Makefile.config
data/*
models/*
*.caffemodel
*.caffemodel.h5
*.solverstate
*.solverstate.h5
*.binaryproto
*leveldb
*lmdb
Expand Down
4 changes: 2 additions & 2 deletions examples/cifar10/train_full.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ $TOOLS/caffe train \
# reduce learning rate by factor of 10
$TOOLS/caffe train \
--solver=examples/cifar10/cifar10_full_solver_lr1.prototxt \
--snapshot=examples/cifar10/cifar10_full_iter_60000.solverstate
--snapshot=examples/cifar10/cifar10_full_iter_60000.solverstate.h5

# reduce learning rate by factor of 10
$TOOLS/caffe train \
--solver=examples/cifar10/cifar10_full_solver_lr2.prototxt \
--snapshot=examples/cifar10/cifar10_full_iter_65000.solverstate
--snapshot=examples/cifar10/cifar10_full_iter_65000.solverstate.h5
2 changes: 1 addition & 1 deletion examples/cifar10/train_quick.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ $TOOLS/caffe train \
# reduce learning rate by factor of 10 after 8 epochs
$TOOLS/caffe train \
--solver=examples/cifar10/cifar10_quick_solver_lr1.prototxt \
--snapshot=examples/cifar10/cifar10_quick_iter_4000.solverstate
--snapshot=examples/cifar10/cifar10_quick_iter_4000.solverstate.h5
2 changes: 1 addition & 1 deletion examples/imagenet/resume_training.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

./build/tools/caffe train \
--solver=models/bvlc_reference_caffenet/solver.prototxt \
--snapshot=models/bvlc_reference_caffenet/caffenet_train_10000.solverstate
--snapshot=models/bvlc_reference_caffenet/caffenet_train_10000.solverstate.h5
2 changes: 1 addition & 1 deletion include/caffe/blob.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#include "caffe/syncedmem.hpp"
#include "caffe/util/math_functions.hpp"

const int kMaxBlobAxes = INT_MAX;
const int kMaxBlobAxes = 32;

namespace caffe {

Expand Down
4 changes: 4 additions & 0 deletions include/caffe/net.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,12 @@ class Net {
*/
void CopyTrainedLayersFrom(const NetParameter& param);
void CopyTrainedLayersFrom(const string trained_filename);
void CopyTrainedLayersFromBinaryProto(const string trained_filename);
void CopyTrainedLayersFromHDF5(const string trained_filename);
/// @brief Writes the net to a proto.
void ToProto(NetParameter* param, bool write_diff = false) const;
/// @brief Writes the net to an HDF5 file.
void ToHDF5(const string& filename, bool write_diff = false) const;

/// @brief returns the network name.
inline const string& name() const { return name_; }
Expand Down
21 changes: 14 additions & 7 deletions include/caffe/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ class Solver {
virtual void Solve(const char* resume_file = NULL);
inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
void Step(int iters);
// The Restore function implements how one should restore the solver to a
// previously snapshotted state. You should implement the RestoreSolverState()
// function that restores the state from a SolverState protocol buffer.
// The Restore method simply dispatches to one of the
// RestoreSolverStateFrom___ protected methods. You should implement these
// methods to restore the state from the appropriate snapshot type.
void Restore(const char* resume_file);
virtual ~Solver() {}
inline shared_ptr<Net<Dtype> > net() { return net_; }
Expand All @@ -46,11 +46,15 @@ class Solver {
// function that produces a SolverState protocol buffer that needs to be
// written to disk together with the learned net.
void Snapshot();
string SnapshotFilename(const string extension);
string SnapshotToBinaryProto();
string SnapshotToHDF5();
// The test routine
void TestAll();
void Test(const int test_net_id = 0);
virtual void SnapshotSolverState(SolverState* state) = 0;
virtual void RestoreSolverState(const SolverState& state) = 0;
virtual void SnapshotSolverState(const string& model_filename) = 0;
virtual void RestoreSolverStateFromHDF5(const string& state_file) = 0;
virtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0;
void DisplayOutputBlobs(const int net_id);

SolverParameter param_;
Expand Down Expand Up @@ -85,8 +89,11 @@ class SGDSolver : public Solver<Dtype> {
virtual void Regularize(int param_id);
virtual void ComputeUpdateValue(int param_id, Dtype rate);
virtual void ClipGradients();
virtual void SnapshotSolverState(SolverState * state);
virtual void RestoreSolverState(const SolverState& state);
virtual void SnapshotSolverState(const string& model_filename);
virtual void SnapshotSolverStateToBinaryProto(const string& model_filename);
virtual void SnapshotSolverStateToHDF5(const string& model_filename);
virtual void RestoreSolverStateFromHDF5(const string& state_file);
virtual void RestoreSolverStateFromBinaryProto(const string& state_file);
// history maintains the historical momentum data.
// update maintains update related data and is not needed in snapshots.
// temp maintains other information that might be needed in computation
Expand Down
39 changes: 39 additions & 0 deletions include/caffe/util/hdf5.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#ifndef CAFFE_UTIL_HDF5_H_
#define CAFFE_UTIL_HDF5_H_

#include <string>

#include "hdf5.h"
#include "hdf5_hl.h"

#include "caffe/blob.hpp"

namespace caffe {

template <typename Dtype>
void hdf5_load_nd_dataset_helper(
hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
Blob<Dtype>* blob);

template <typename Dtype>
void hdf5_load_nd_dataset(
hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
Blob<Dtype>* blob);

template <typename Dtype>
void hdf5_save_nd_dataset(
const hid_t file_id, const string& dataset_name, const Blob<Dtype>& blob,
bool write_diff = false);

int hdf5_load_int(hid_t loc_id, const string& dataset_name);
void hdf5_save_int(hid_t loc_id, const string& dataset_name, int i);
string hdf5_load_string(hid_t loc_id, const string& dataset_name);
void hdf5_save_string(hid_t loc_id, const string& dataset_name,
const string& s);

int hdf5_get_num_links(hid_t loc_id);
string hdf5_get_name_by_idx(hid_t loc_id, int idx);

} // namespace caffe

#endif // CAFFE_UTIL_HDF5_H_
18 changes: 0 additions & 18 deletions include/caffe/util/io.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,11 @@
#include <string>

#include "google/protobuf/message.h"
#include "hdf5.h"
#include "hdf5_hl.h"

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/proto/caffe.pb.h"

#define HDF5_NUM_DIMS 4

namespace caffe {

using ::google::protobuf::Message;
Expand Down Expand Up @@ -140,20 +136,6 @@ cv::Mat DecodeDatumToCVMat(const Datum& datum, bool is_color);

void CVMatToDatum(const cv::Mat& cv_img, Datum* datum);

template <typename Dtype>
void hdf5_load_nd_dataset_helper(
hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
Blob<Dtype>* blob);

template <typename Dtype>
void hdf5_load_nd_dataset(
hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
Blob<Dtype>* blob);

template <typename Dtype>
void hdf5_save_nd_dataset(
const hid_t file_id, const string& dataset_name, const Blob<Dtype>& blob);

} // namespace caffe

#endif // CAFFE_UTIL_IO_H_
49 changes: 42 additions & 7 deletions src/caffe/blob.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,31 +456,66 @@ void Blob<Dtype>::FromProto(const BlobProto& proto, bool reshape) {
}
// copy data
Dtype* data_vec = mutable_cpu_data();
for (int i = 0; i < count_; ++i) {
data_vec[i] = proto.data(i);
if (proto.double_data_size() > 0) {
CHECK_EQ(count_, proto.double_data_size());
for (int i = 0; i < count_; ++i) {
data_vec[i] = proto.double_data(i);
}
} else {
CHECK_EQ(count_, proto.data_size());
for (int i = 0; i < count_; ++i) {
data_vec[i] = proto.data(i);
}
}
if (proto.diff_size() > 0) {
if (proto.double_diff_size() > 0) {
CHECK_EQ(count_, proto.double_diff_size());
Dtype* diff_vec = mutable_cpu_diff();
for (int i = 0; i < count_; ++i) {
diff_vec[i] = proto.double_diff(i);
}
} else if (proto.diff_size() > 0) {
CHECK_EQ(count_, proto.diff_size());
Dtype* diff_vec = mutable_cpu_diff();
for (int i = 0; i < count_; ++i) {
diff_vec[i] = proto.diff(i);
}
}
}

template <typename Dtype>
void Blob<Dtype>::ToProto(BlobProto* proto, bool write_diff) const {
template <>
void Blob<double>::ToProto(BlobProto* proto, bool write_diff) const {
proto->clear_shape();
for (int i = 0; i < shape_.size(); ++i) {
proto->mutable_shape()->add_dim(shape_[i]);
}
proto->clear_double_data();
proto->clear_double_diff();
const double* data_vec = cpu_data();
for (int i = 0; i < count_; ++i) {
proto->add_double_data(data_vec[i]);
}
if (write_diff) {
const double* diff_vec = cpu_diff();
for (int i = 0; i < count_; ++i) {
proto->add_double_diff(diff_vec[i]);
}
}
}

template <>
void Blob<float>::ToProto(BlobProto* proto, bool write_diff) const {
proto->clear_shape();
for (int i = 0; i < shape_.size(); ++i) {
proto->mutable_shape()->add_dim(shape_[i]);
}
proto->clear_data();
proto->clear_diff();
const Dtype* data_vec = cpu_data();
const float* data_vec = cpu_data();
for (int i = 0; i < count_; ++i) {
proto->add_data(data_vec[i]);
}
if (write_diff) {
const Dtype* diff_vec = cpu_diff();
const float* diff_vec = cpu_diff();
for (int i = 0; i < count_; ++i) {
proto->add_diff(diff_vec[i]);
}
Expand Down
2 changes: 1 addition & 1 deletion src/caffe/layers/hdf5_data_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

#include "caffe/data_layers.hpp"
#include "caffe/layer.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/hdf5.hpp"

namespace caffe {

Expand Down
2 changes: 1 addition & 1 deletion src/caffe/layers/hdf5_output_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/layer.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/hdf5.hpp"
#include "caffe/vision_layers.hpp"

namespace caffe {
Expand Down
1 change: 0 additions & 1 deletion src/caffe/layers/hdf5_output_layer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/layer.hpp"
#include "caffe/util/io.hpp"
#include "caffe/vision_layers.hpp"

namespace caffe {
Expand Down
Loading