Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test solvers on fixed hdf5 data #2887

Merged
merged 2 commits into from
Aug 9, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion src/caffe/test/test_data/generate_sample_data.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""
Generate data used in the HDF5DataLayer test.
Generate data used in the HDF5DataLayer and GradientBasedSolver tests.
"""
import os
import numpy as np
import h5py

script_dir = os.path.dirname(os.path.abspath(__file__))

# Generate HDF5DataLayer sample_data.h5

num_cols = 8
num_rows = 10
height = 6
Expand Down Expand Up @@ -51,3 +53,27 @@
with open(script_dir + '/sample_data_list.txt', 'w') as f:
f.write(script_dir + '/sample_data.h5\n')
f.write(script_dir + '/sample_data_2_gzip.h5\n')

# Generate GradientBasedSolver solver_data.h5

num_cols = 3
num_rows = 8
height = 10
width = 10

data = np.random.randn(num_rows, num_cols, height, width)
data = data.reshape(num_rows, num_cols, height, width)
data = data.astype('float32')

targets = np.random.randn(num_rows, 1)
targets = targets.astype('float32')

print data
print targets

with h5py.File(script_dir + '/solver_data.h5', 'w') as f:
f['data'] = data
f['targets'] = targets

with open(script_dir + '/solver_data_list.txt', 'w') as f:
f.write(script_dir + '/solver_data.h5\n')
Binary file added src/caffe/test/test_data/solver_data.h5
Binary file not shown.
1 change: 1 addition & 0 deletions src/caffe/test/test_data/solver_data_list.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
src/caffe/test/test_data/solver_data.h5
40 changes: 17 additions & 23 deletions src/caffe/test/test_gradient_based_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,26 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
protected:
GradientBasedSolverTest() :
seed_(1701), num_(4), channels_(3), height_(10), width_(10),
constant_data_(false), share_(false) {}
share_(false) {
input_file_ = new string(
CMAKE_SOURCE_DIR "caffe/test/test_data/solver_data_list.txt" CMAKE_EXT);
}
~GradientBasedSolverTest() {
delete input_file_;
}

string snapshot_prefix_;
shared_ptr<SGDSolver<Dtype> > solver_;
int seed_;
// Dimensions are determined by generate_sample_data.py
// TODO this is brittle and the hdf5 file should be checked instead.
int num_, channels_, height_, width_;
bool constant_data_, share_;
bool share_;
Dtype delta_; // Stability constant for AdaGrad.

// Test data: check out generate_sample_data.py in the same directory.
string* input_file_;

virtual SolverParameter_SolverType solver_type() = 0;
virtual void InitSolver(const SolverParameter& param) = 0;

Expand Down Expand Up @@ -71,25 +82,10 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
" name: 'TestNetwork' "
" layer { "
" name: 'data' "
" type: 'DummyData' "
" dummy_data_param { "
" num: " << num_ / iter_size << " "
" channels: " << channels_ << " "
" height: " << height_ << " "
" width: " << width_ << " "
" channels: 1 "
" height: 1 "
" width: 1 "
" data_filler { "
" type: '" << string(constant_data_ ? "constant" : "gaussian")
<< "' "
" std: 1.0 "
" value: 1.0 "
" } "
" data_filler { "
" type: 'gaussian' "
" std: 1.0 "
" } "
" type: 'HDF5Data' "
" hdf5_data_param { "
" source: '" << *(this->input_file_) << "' "
" batch_size: " << num_ / iter_size << " "
" } "
" top: 'data' "
" top: 'targets' "
Expand Down Expand Up @@ -180,7 +176,6 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
}
Caffe::set_random_seed(this->seed_);
this->InitSolverFromProtoString(proto.str());
Caffe::set_random_seed(this->seed_);
if (from_snapshot != NULL) {
this->solver_->Restore(from_snapshot);
vector<Blob<Dtype>*> empty_bottom_vec;
Expand Down Expand Up @@ -355,7 +350,6 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
const Dtype kMomentum, const int kNumIters, const int kIterSize) {
const double kPrecision = 1e-2;
const double kMinPrecision = 1e-7;
constant_data_ = true;
// Solve without accumulation and save parameters.
this->RunLeastSquaresSolver(kLearningRate, kWeightDecay, kMomentum,
kNumIters);
Expand Down