From 1856bb2d89ecbabe84cbf0a647de7a2ab6c29207 Mon Sep 17 00:00:00 2001
From: Sergio Guadarrama
Date: Fri, 16 Jan 2015 13:18:50 -0800
Subject: [PATCH 1/4] add db wrappers

---
 include/caffe/util/db.hpp | 190 ++++++++++++++++++++++++++++++++++++++
 src/caffe/util/db.cpp     |  84 +++++++++++++++++
 2 files changed, 274 insertions(+)
 create mode 100644 include/caffe/util/db.hpp
 create mode 100644 src/caffe/util/db.cpp

diff --git a/include/caffe/util/db.hpp b/include/caffe/util/db.hpp
new file mode 100644
index 00000000000..afdb8d2c4f8
--- /dev/null
+++ b/include/caffe/util/db.hpp
@@ -0,0 +1,190 @@
+#ifndef CAFFE_UTIL_DB_HPP
+#define CAFFE_UTIL_DB_HPP
+
+#include <string>
+
+#include "leveldb/db.h"
+#include "leveldb/write_batch.h"
+#include "lmdb.h"
+
+#include "caffe/common.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+namespace caffe { namespace db {
+
+enum Mode { READ, WRITE, NEW };
+
+class Cursor {
+ public:
+  Cursor() { }
+  virtual ~Cursor() { }
+  virtual void SeekToFirst() = 0;
+  virtual void Next() = 0;
+  virtual string key() = 0;
+  virtual string value() = 0;
+  virtual bool valid() = 0;
+
+  DISABLE_COPY_AND_ASSIGN(Cursor);
+};
+
+class Transaction {
+ public:
+  Transaction() { }
+  virtual ~Transaction() { }
+  virtual void Put(const string& key, const string& value) = 0;
+  virtual void Commit() = 0;
+
+  DISABLE_COPY_AND_ASSIGN(Transaction);
+};
+
+class DB {
+ public:
+  DB() { }
+  virtual ~DB() { }
+  virtual void Open(const string& source, Mode mode) = 0;
+  virtual void Close() = 0;
+  virtual Cursor* NewCursor() = 0;
+  virtual Transaction* NewTransaction() = 0;
+
+  DISABLE_COPY_AND_ASSIGN(DB);
+};
+
+class LevelDBCursor : public Cursor {
+ public:
+  explicit LevelDBCursor(leveldb::Iterator* iter)
+    : iter_(iter) { SeekToFirst(); }
+  ~LevelDBCursor() { delete iter_; }
+  virtual void SeekToFirst() { iter_->SeekToFirst(); }
+  virtual void Next() { iter_->Next(); }
+  virtual string key() { return iter_->key().ToString(); }
+  virtual string value() { return iter_->value().ToString(); }
+  virtual bool valid() { return iter_->Valid(); }
+
+ private:
+  leveldb::Iterator* iter_;
+};
+
+class LevelDBTransaction : public Transaction {
+ public:
+  explicit LevelDBTransaction(leveldb::DB* db) : db_(db) { CHECK_NOTNULL(db_); }
+  virtual void Put(const string& key, const string& value) {
+    batch_.Put(key, value);
+  }
+  virtual void Commit() {
+    leveldb::Status status = db_->Write(leveldb::WriteOptions(), &batch_);
+    CHECK(status.ok()) << "Failed to write batch to leveldb "
+                       << std::endl << status.ToString();
+  }
+
+ private:
+  leveldb::DB* db_;
+  leveldb::WriteBatch batch_;
+
+  DISABLE_COPY_AND_ASSIGN(LevelDBTransaction);
+};
+
+class LevelDB : public DB {
+ public:
+  LevelDB() : db_(NULL) { }
+  virtual ~LevelDB() { Close(); }
+  virtual void Open(const string& source, Mode mode);
+  virtual void Close() {
+    if (db_ != NULL) {
+      delete db_;
+      db_ = NULL;
+    }
+  }
+  virtual LevelDBCursor* NewCursor() {
+    return new LevelDBCursor(db_->NewIterator(leveldb::ReadOptions()));
+  }
+  virtual LevelDBTransaction* NewTransaction() {
+    return new LevelDBTransaction(db_);
+  }
+
+ private:
+  leveldb::DB* db_;
+};
+
+inline void MDB_CHECK(int mdb_status) {
+  CHECK_EQ(mdb_status, MDB_SUCCESS) << mdb_strerror(mdb_status);
+}
+
+class LMDBCursor : public Cursor {
+ public:
+  explicit LMDBCursor(MDB_txn* mdb_txn, MDB_cursor* mdb_cursor)
+    : mdb_txn_(mdb_txn), mdb_cursor_(mdb_cursor), valid_(false) {
+    SeekToFirst();
+  }
+  virtual ~LMDBCursor() {
+    mdb_cursor_close(mdb_cursor_);
+    mdb_txn_abort(mdb_txn_);
+  }
+  virtual void SeekToFirst() { Seek(MDB_FIRST); }
+  virtual void Next() { Seek(MDB_NEXT); }
+  virtual string key() {
+    return string(static_cast<const char*>(mdb_key_.mv_data), mdb_key_.mv_size);
+  }
+  virtual string value() {
+    return string(static_cast<const char*>(mdb_value_.mv_data),
+                  mdb_value_.mv_size);
+  }
+  virtual bool valid() { return valid_; }
+
+ private:
+  void Seek(MDB_cursor_op op) {
+    int mdb_status = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, op);
+    if (mdb_status == MDB_NOTFOUND) {
+      valid_ = false;
+    } else {
+      MDB_CHECK(mdb_status);
+      valid_ = true;
+    }
+  }
+
+  MDB_txn* mdb_txn_;
+  MDB_cursor* mdb_cursor_;
+  MDB_val mdb_key_, mdb_value_;
+  bool valid_;
+};
+
+class LMDBTransaction : public Transaction {
+ public:
+  explicit LMDBTransaction(MDB_dbi* mdb_dbi, MDB_txn* mdb_txn)
+    : mdb_dbi_(mdb_dbi), mdb_txn_(mdb_txn) { }
+  virtual void Put(const string& key, const string& value);
+  virtual void Commit() { MDB_CHECK(mdb_txn_commit(mdb_txn_)); }
+
+ private:
+  MDB_dbi* mdb_dbi_;
+  MDB_txn* mdb_txn_;
+
+  DISABLE_COPY_AND_ASSIGN(LMDBTransaction);
+};
+
+class LMDB : public DB {
+ public:
+  LMDB() : mdb_env_(NULL) { }
+  virtual ~LMDB() { Close(); }
+  virtual void Open(const string& source, Mode mode);
+  virtual void Close() {
+    if (mdb_env_ != NULL) {
+      mdb_dbi_close(mdb_env_, mdb_dbi_);
+      mdb_env_close(mdb_env_);
+      mdb_env_ = NULL;
+    }
+  }
+  virtual LMDBCursor* NewCursor();
+  virtual LMDBTransaction* NewTransaction();
+
+ private:
+  MDB_env* mdb_env_;
+  MDB_dbi mdb_dbi_;
+};
+
+DB* GetDB(DataParameter::DB backend);
+DB* GetDB(const string& backend);
+
+} // namespace db
+} // namespace caffe
+
+#endif  // CAFFE_UTIL_DB_HPP
diff --git a/src/caffe/util/db.cpp b/src/caffe/util/db.cpp
new file mode 100644
index 00000000000..7f7018107ec
--- /dev/null
+++ b/src/caffe/util/db.cpp
@@ -0,0 +1,84 @@
+#include "caffe/util/db.hpp"
+
+#include <sys/stat.h>
+#include <string>
+
+namespace caffe { namespace db {
+
+const size_t LMDB_MAP_SIZE = 1099511627776;  // 1 TB
+
+void LevelDB::Open(const string& source, Mode mode) {
+  leveldb::Options options;
+  options.block_size = 65536;
+  options.write_buffer_size = 268435456;
+  options.max_open_files = 100;
+  options.error_if_exists = mode == NEW;
+  options.create_if_missing = mode != READ;
+  leveldb::Status status = leveldb::DB::Open(options, source, &db_);
+  CHECK(status.ok()) << "Failed to open leveldb " << source
+                     << std::endl << status.ToString();
+  LOG(INFO) << "Opened leveldb " << source;
+}
+
+void LMDB::Open(const string& source, Mode mode) {
+  MDB_CHECK(mdb_env_create(&mdb_env_));
+  MDB_CHECK(mdb_env_set_mapsize(mdb_env_, LMDB_MAP_SIZE));
+  if (mode == NEW) {
+    CHECK_EQ(mkdir(source.c_str(), 0744), 0) << "mkdir " << source << " failed";
+  }
+  int flags = 0;
+  if (mode == READ) {
+    flags = MDB_RDONLY | MDB_NOTLS;
+  }
+  MDB_CHECK(mdb_env_open(mdb_env_, source.c_str(), flags, 0664));
+  LOG(INFO) << "Opened lmdb " << source;
+}
+
+LMDBCursor* LMDB::NewCursor() {
+  MDB_txn* mdb_txn;
+  MDB_cursor* mdb_cursor;
+  MDB_CHECK(mdb_txn_begin(mdb_env_, NULL, MDB_RDONLY, &mdb_txn));
+  MDB_CHECK(mdb_dbi_open(mdb_txn, NULL, 0, &mdb_dbi_));
+  MDB_CHECK(mdb_cursor_open(mdb_txn, mdb_dbi_, &mdb_cursor));
+  return new LMDBCursor(mdb_txn, mdb_cursor);
+}
+
+LMDBTransaction* LMDB::NewTransaction() {
+  MDB_txn* mdb_txn;
+  MDB_CHECK(mdb_txn_begin(mdb_env_, NULL, 0, &mdb_txn));
+  MDB_CHECK(mdb_dbi_open(mdb_txn, NULL, 0, &mdb_dbi_));
+  return new LMDBTransaction(&mdb_dbi_, mdb_txn);
+}
+
+void LMDBTransaction::Put(const string& key, const string& value) {
+  MDB_val mdb_key, mdb_value;
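+  // mdb_put() copies the key and value into the database, so pointing
+  // MDB_val at the caller's buffers is safe here; const_cast is needed
+  // only because MDB_val::mv_data is a non-const void*.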
+  mdb_key.mv_data = const_cast<char*>(key.data());
+  mdb_key.mv_size = key.size();
+  mdb_value.mv_data = const_cast<char*>(value.data());
+  mdb_value.mv_size = value.size();
+  MDB_CHECK(mdb_put(mdb_txn_, *mdb_dbi_, &mdb_key, &mdb_value, 0));
+}
+
+DB* GetDB(DataParameter::DB backend) {
+  switch (backend) {
+  case DataParameter_DB_LEVELDB:
+    return new LevelDB();
+  case DataParameter_DB_LMDB:
+    return new LMDB();
+  default:
+    LOG(FATAL) << "Unknown database backend";
+  }
+}
+
+DB* GetDB(const string& backend) {
+  if (backend == "leveldb") {
+    return new LevelDB();
+  } else if (backend == "lmdb") {
+    return new LMDB();
+  } else {
+    LOG(FATAL) << "Unknown database backend";
+  }
+}
+
+} // namespace db
+} // namespace caffe

From 7dfe23963c69759c3bdca1f0ee8ff9866c0fb5c2 Mon Sep 17 00:00:00 2001
From: Jonathan L Long
Date: Fri, 16 Jan 2015 13:20:15 -0800
Subject: [PATCH 2/4] use db wrappers

---
 examples/cifar10/convert_cifar_data.cpp | 41 ++++++++++++-----------
 include/caffe/data_layers.hpp           |  6 ++--
 src/caffe/layers/data_layer.cpp         | 34 ++++++++-----------
 src/caffe/test/test_data_layer.cpp      | 27 ++++++++-------
 tools/compute_image_mean.cpp            | 44 +++++++++++--------
 tools/convert_imageset.cpp              | 38 ++++++++++-----------
 tools/extract_features.cpp              | 31 +++++++++--------
 7 files changed, 111 insertions(+), 110 deletions(-)

diff --git a/examples/cifar10/convert_cifar_data.cpp b/examples/cifar10/convert_cifar_data.cpp
index 9eecc74c989..f4c42e4d2e7 100644
--- a/examples/cifar10/convert_cifar_data.cpp
+++ b/examples/cifar10/convert_cifar_data.cpp
@@ -9,19 +9,18 @@
 #include <fstream>  // NOLINT(readability/streams)
 #include <string>
 
+#include "boost/scoped_ptr.hpp"
 #include "glog/logging.h"
 #include "google/protobuf/text_format.h"
 #include "stdint.h"
 
-#include "caffe/dataset_factory.hpp"
 #include "caffe/proto/caffe.pb.h"
+#include "caffe/util/db.hpp"
 
-using std::string;
-
-using caffe::Dataset;
-using caffe::DatasetFactory;
 using caffe::Datum;
-using caffe::shared_ptr;
+using boost::scoped_ptr;
+using std::string;
+namespace db = caffe::db;
 
 const int kCIFARSize = 32;
 const int kCIFARImageNBytes = 3072;
@@ -38,10 +37,9 @@ void read_image(std::ifstream* file, int* label, char* buffer) {
 
 void convert_dataset(const string& input_folder, const string& output_folder,
     const string& db_type) {
-  shared_ptr<Dataset<string, Datum> > train_dataset =
-      DatasetFactory<string, Datum>(db_type);
-  CHECK(train_dataset->open(output_folder + "/cifar10_train_" + db_type,
-      Dataset<string, Datum>::New));
+  scoped_ptr<db::DB> train_db(db::GetDB(db_type));
+  train_db->Open(output_folder + "/cifar10_train_" + db_type, db::NEW);
+  scoped_ptr<db::Transaction> txn(train_db->NewTransaction());
   // Data buffer
   int label;
   char str_buffer[kCIFARImageNBytes];
@@ -64,17 +62,18 @@ void convert_dataset(const string& input_folder, const string& output_folder,
       datum.set_data(str_buffer, kCIFARImageNBytes);
       int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d",
           fileid * kCIFARBatchSize + itemid);
-      CHECK(train_dataset->put(string(str_buffer, length), datum));
+      string out;
+      CHECK(datum.SerializeToString(&out));
+      txn->Put(string(str_buffer, length), out);
     }
   }
-  CHECK(train_dataset->commit());
-  train_dataset->close();
+  txn->Commit();
+  train_db->Close();
 
   LOG(INFO) << "Writing Testing data";
-  shared_ptr<Dataset<string, Datum> > test_dataset =
-      DatasetFactory<string, Datum>(db_type);
-  CHECK(test_dataset->open(output_folder + "/cifar10_test_" + db_type,
-      Dataset<string, Datum>::New));
+  scoped_ptr<db::DB> test_db(db::GetDB(db_type));
+  test_db->Open(output_folder + "/cifar10_test_" + db_type, db::NEW);
+  txn.reset(test_db->NewTransaction());
   // Open files
   std::ifstream data_file((input_folder + "/test_batch.bin").c_str(),
      std::ios::in | std::ios::binary);
@@ -84,10 +83,12 @@ void convert_dataset(const string& input_folder, const string& output_folder,
    datum.set_label(label);
    datum.set_data(str_buffer, kCIFARImageNBytes);
    int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d", itemid);
-    CHECK(test_dataset->put(string(str_buffer, length), datum));
+    string out;
+    CHECK(datum.SerializeToString(&out));
+    txn->Put(string(str_buffer, length), out);
   }
-  CHECK(test_dataset->commit());
-  test_dataset->close();
+  txn->Commit();
+  test_db->Close();
 }
 
 int main(int argc, char** argv) {
diff --git a/include/caffe/data_layers.hpp b/include/caffe/data_layers.hpp
index 34b9b30aa3e..e9c83856a99 100644
--- a/include/caffe/data_layers.hpp
+++ b/include/caffe/data_layers.hpp
@@ -11,11 +11,11 @@
 #include "caffe/blob.hpp"
 #include "caffe/common.hpp"
 #include "caffe/data_transformer.hpp"
-#include "caffe/dataset.hpp"
 #include "caffe/filler.hpp"
 #include "caffe/internal_thread.hpp"
 #include "caffe/layer.hpp"
 #include "caffe/proto/caffe.pb.h"
+#include "caffe/util/db.hpp"
 
 namespace caffe {
 
@@ -100,8 +100,8 @@ class DataLayer : public BasePrefetchingDataLayer<Dtype> {
  protected:
   virtual void InternalThreadEntry();
 
-  shared_ptr<Dataset<string, Datum> > dataset_;
-  Dataset<string, Datum>::const_iterator iter_;
+  shared_ptr<db::DB> db_;
+  shared_ptr<db::Cursor> cursor_;
 };
 
 /**
diff --git a/src/caffe/layers/data_layer.cpp b/src/caffe/layers/data_layer.cpp
index a5030899cee..96964566630 100644
--- a/src/caffe/layers/data_layer.cpp
+++ b/src/caffe/layers/data_layer.cpp
@@ -7,7 +7,6 @@
 #include "caffe/common.hpp"
 #include "caffe/data_layers.hpp"
-#include "caffe/dataset_factory.hpp"
 #include "caffe/layer.hpp"
 #include "caffe/proto/caffe.pb.h"
 #include "caffe/util/benchmark.hpp"
@@ -20,35 +19,28 @@ namespace caffe {
 
 template <typename Dtype>
 DataLayer<Dtype>::~DataLayer() {
   this->JoinPrefetchThread();
-  // clean up the dataset resources
-  dataset_->close();
 }
 
 template <typename Dtype>
 void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top) {
   // Initialize DB
-  dataset_ = DatasetFactory<string, Datum>(
-      this->layer_param_.data_param().backend());
-  const string& source = this->layer_param_.data_param().source();
-  LOG(INFO) << "Opening dataset " << source;
-  CHECK(dataset_->open(source, Dataset<string, Datum>::ReadOnly));
-  iter_ = dataset_->begin();
+  db_.reset(db::GetDB(this->layer_param_.data_param().backend()));
+  db_->Open(this->layer_param_.data_param().source(), db::READ);
+  cursor_.reset(db_->NewCursor());
 
-  // Check if we would need to randomly skip a few data points
+  // Check if we should randomly skip a few data points
   if (this->layer_param_.data_param().rand_skip()) {
     unsigned int skip = caffe_rng_rand() %
         this->layer_param_.data_param().rand_skip();
     LOG(INFO) << "Skipping first " << skip << " data points.";
     while (skip-- > 0) {
-      if (++iter_ == dataset_->end()) {
-        iter_ = dataset_->begin();
-      }
+      cursor_->Next();
     }
   }
   // Read a data point, and use it to initialize the top blob.
- CHECK(iter_ != dataset_->end()); - Datum datum = iter_->value; + Datum datum; + datum.ParseFromString(cursor_->value()); if (DecodeDatum(&datum)) { LOG(INFO) << "Decoding Datum"; @@ -101,8 +93,8 @@ void DataLayer::InternalThreadEntry() { for (int item_id = 0; item_id < batch_size; ++item_id) { timer.Start(); // get a blob - CHECK(iter_ != dataset_->end()); - const Datum& datum = iter_->value; + Datum datum; + datum.ParseFromString(cursor_->value()); cv::Mat cv_img; if (datum.encoded()) { @@ -124,9 +116,10 @@ void DataLayer::InternalThreadEntry() { } trans_time += timer.MicroSeconds(); // go to the next iter - ++iter_; - if (iter_ == dataset_->end()) { - iter_ = dataset_->begin(); + cursor_->Next(); + if (!cursor_->valid()) { + DLOG(INFO) << "Restarting data prefetching from start."; + cursor_->SeekToFirst(); } } batch_timer.Stop(); @@ -137,4 +130,5 @@ void DataLayer::InternalThreadEntry() { INSTANTIATE_CLASS(DataLayer); REGISTER_LAYER_CLASS(DATA, DataLayer); + } // namespace caffe diff --git a/src/caffe/test/test_data_layer.cpp b/src/caffe/test/test_data_layer.cpp index 32f5d41e2f2..8cc31b24167 100644 --- a/src/caffe/test/test_data_layer.cpp +++ b/src/caffe/test/test_data_layer.cpp @@ -1,20 +1,23 @@ #include #include +#include "boost/scoped_ptr.hpp" #include "gtest/gtest.h" #include "caffe/blob.hpp" #include "caffe/common.hpp" -#include "caffe/dataset_factory.hpp" +#include "caffe/data_layers.hpp" #include "caffe/filler.hpp" #include "caffe/proto/caffe.pb.h" +#include "caffe/util/db.hpp" #include "caffe/util/io.hpp" -#include "caffe/vision_layers.hpp" #include "caffe/test/test_caffe_main.hpp" namespace caffe { +using boost::scoped_ptr; + template class DataLayerTest : public MultiDeviceTest { typedef typename TypeParam::Dtype Dtype; @@ -33,15 +36,15 @@ class DataLayerTest : public MultiDeviceTest { blob_top_vec_.push_back(blob_top_label_); } - // Fill the LevelDB with data: if unique_pixels, each pixel is unique but + // Fill the DB with data: if unique_pixels, each pixel is unique but // all images are the same; else each image is unique but all pixels within // an image are the same. void Fill(const bool unique_pixels, DataParameter_DB backend) { backend_ = backend; LOG(INFO) << "Using temporary dataset " << *filename_; - shared_ptr > dataset = - DatasetFactory(backend_); - CHECK(dataset->open(*filename_, Dataset::New)); + scoped_ptr db(db::GetDB(backend)); + db->Open(*filename_, db::NEW); + scoped_ptr txn(db->NewTransaction()); for (int i = 0; i < 5; ++i) { Datum datum; datum.set_label(i); @@ -55,10 +58,12 @@ class DataLayerTest : public MultiDeviceTest { } stringstream ss; ss << i; - CHECK(dataset->put(ss.str(), datum)); + string out; + CHECK(datum.SerializeToString(&out)); + txn->Put(ss.str(), out); } - CHECK(dataset->commit()); - dataset->close(); + txn->Commit(); + db->Close(); } void TestRead() { @@ -183,7 +188,7 @@ class DataLayerTest : public MultiDeviceTest { } crop_sequence.push_back(iter_crop_sequence); } - } // destroy 1st data layer and unlock the dataset + } // destroy 1st data layer and unlock the db // Get crop sequence after reseeding Caffe with 1701. // Check that the sequence is the same as the original. @@ -238,7 +243,7 @@ class DataLayerTest : public MultiDeviceTest { } crop_sequence.push_back(iter_crop_sequence); } - } // destroy 1st data layer and unlock the dataset + } // destroy 1st data layer and unlock the db // Get crop sequence continuing from previous Caffe RNG state; reseed // srand with 1701. Check that the sequence differs from the original. 
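Every call site in this patch converges on the same two idioms: a cursor walk for reading and a transaction for batched writing. The following standalone sketch shows the read idiom end to end; it is an illustration rather than part of the patch, and the database path and main() wrapper are assumptions (any Datum-filled DB, such as the one written by convert_cifar_data above, would do):

#include "boost/scoped_ptr.hpp"

#include "caffe/proto/caffe.pb.h"
#include "caffe/util/db.hpp"

using boost::scoped_ptr;
namespace db = caffe::db;

int main() {
  // GetDB also accepts the DataParameter::DB enum used by layer protos.
  scoped_ptr<db::DB> db(db::GetDB("lmdb"));
  db->Open("./cifar10_train_lmdb", db::READ);  // path is hypothetical
  scoped_ptr<db::Cursor> cursor(db->NewCursor());  // positioned at first key
  for (; cursor->valid(); cursor->Next()) {
    caffe::Datum datum;
    datum.ParseFromString(cursor->value());
    LOG(INFO) << cursor->key() << ": " << datum.channels() << "x"
              << datum.height() << "x" << datum.width();
  }
  return 0;  // scoped_ptr releases the cursor before the DB
}

The write idiom is symmetric: Open(..., db::NEW), NewTransaction(), SerializeToString() each Datum, Put(), and a periodic Commit() to bound the size of the pending batch, as convert_imageset and extract_features do below.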
diff --git a/tools/compute_image_mean.cpp b/tools/compute_image_mean.cpp index 358f57e38d6..dff63a09dca 100644 --- a/tools/compute_image_mean.cpp +++ b/tools/compute_image_mean.cpp @@ -1,24 +1,25 @@ -#include -#include #include - #include #include #include #include -#include "caffe/dataset_factory.hpp" +#include "boost/scoped_ptr.hpp" +#include "gflags/gflags.h" +#include "glog/logging.h" + #include "caffe/proto/caffe.pb.h" +#include "caffe/util/db.hpp" #include "caffe/util/io.hpp" -using caffe::Dataset; -using caffe::Datum; -using caffe::BlobProto; +using namespace caffe; // NOLINT(build/namespaces) + using std::max; using std::pair; +using boost::scoped_ptr; - -DEFINE_string(backend, "lmdb", "The backend for containing the images"); +DEFINE_string(backend, "lmdb", + "The backend {leveldb, lmdb} containing the images"); int main(int argc, char** argv) { ::google::InitGoogleLogging(argv[0]); @@ -28,7 +29,7 @@ int main(int argc, char** argv) { #endif gflags::SetUsageMessage("Compute the mean_image of a set of images given by" - " a leveldb/lmdb or a list of images\n" + " a leveldb/lmdb\n" "Usage:\n" " compute_image_mean [FLAGS] INPUT_DB [OUTPUT_FILE]\n"); @@ -39,19 +40,15 @@ int main(int argc, char** argv) { return 1; } - std::string db_backend = FLAGS_backend; - - caffe::shared_ptr > dataset = - caffe::DatasetFactory(db_backend); - - // Open db - CHECK(dataset->open(argv[1], Dataset::ReadOnly)); + scoped_ptr db(db::GetDB(FLAGS_backend)); + db->Open(argv[1], db::READ); + scoped_ptr cursor(db->NewCursor()); BlobProto sum_blob; int count = 0; // load first datum - Dataset::const_iterator iter = dataset->begin(); - Datum datum = iter->value; + Datum datum; + datum.ParseFromString(cursor->value()); if (DecodeDatum(&datum)) { LOG(INFO) << "Decoding Datum"; @@ -68,9 +65,9 @@ int main(int argc, char** argv) { sum_blob.add_data(0.); } LOG(INFO) << "Starting Iteration"; - for (Dataset::const_iterator iter = dataset->begin(); - iter != dataset->end(); ++iter) { - Datum datum = iter->value; + while (cursor->valid()) { + Datum datum; + datum.ParseFromString(cursor->value()); DecodeDatum(&datum); const std::string& data = datum.data(); @@ -94,6 +91,7 @@ int main(int argc, char** argv) { if (count % 10000 == 0) { LOG(INFO) << "Processed " << count << " files."; } + cursor->Next(); } if (count % 10000 != 0) { @@ -117,7 +115,5 @@ int main(int argc, char** argv) { } LOG(INFO) << "mean_value channel [" << c << "]:" << mean_values[c] / dim; } - // Clean up - dataset->close(); return 0; } diff --git a/tools/convert_imageset.cpp b/tools/convert_imageset.cpp index 5d42e69c1ed..7fbf5b0514c 100644 --- a/tools/convert_imageset.cpp +++ b/tools/convert_imageset.cpp @@ -8,28 +8,31 @@ // subfolder1/file1.JPEG 7 // .... 
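// A typical invocation with the flags defined below might be (paths are
// illustrative):
//   convert_imageset --backend=lmdb --shuffle ROOTFOLDER/ LISTFILE DB_NAME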
-#include -#include - #include #include // NOLINT(readability/streams) #include #include #include -#include "caffe/dataset_factory.hpp" +#include "boost/scoped_ptr.hpp" +#include "gflags/gflags.h" +#include "glog/logging.h" + #include "caffe/proto/caffe.pb.h" +#include "caffe/util/db.hpp" #include "caffe/util/io.hpp" #include "caffe/util/rng.hpp" using namespace caffe; // NOLINT(build/namespaces) using std::pair; +using boost::scoped_ptr; DEFINE_bool(gray, false, "When this option is on, treat images as grayscale ones"); DEFINE_bool(shuffle, false, "Randomly shuffle the order of images and their labels"); -DEFINE_string(backend, "lmdb", "The backend for storing the result"); +DEFINE_string(backend, "lmdb", + "The backend {lmdb, leveldb} for storing the result"); DEFINE_int32(resize_width, 0, "Width images are resized to"); DEFINE_int32(resize_height, 0, "Height images are resized to"); DEFINE_bool(check_size, false, @@ -75,9 +78,6 @@ int main(int argc, char** argv) { } LOG(INFO) << "A total of " << lines.size() << " images."; - const std::string& db_backend = FLAGS_backend; - const char* db_path = argv[3]; - if (encoded) { CHECK_EQ(FLAGS_resize_height, 0) << "With encoded don't resize images"; CHECK_EQ(FLAGS_resize_width, 0) << "With encoded don't resize images"; @@ -87,12 +87,10 @@ int main(int argc, char** argv) { int resize_height = std::max(0, FLAGS_resize_height); int resize_width = std::max(0, FLAGS_resize_width); - // Open new db - shared_ptr > dataset = - DatasetFactory(db_backend); - - // Open db - CHECK(dataset->open(db_path, Dataset::New)); + // Create new DB + scoped_ptr db(db::GetDB(FLAGS_backend)); + db->Open(argv[3], db::NEW); + scoped_ptr txn(db->NewTransaction()); // Storing to db std::string root_folder(argv[1]); @@ -128,19 +126,21 @@ int main(int argc, char** argv) { lines[line_id].first.c_str()); // Put in db - CHECK(dataset->put(string(key_cstr, length), datum)); + string out; + CHECK(datum.SerializeToString(&out)); + txn->Put(string(key_cstr, length), out); if (++count % 1000 == 0) { - // Commit txn - CHECK(dataset->commit()); + // Commit db + txn->Commit(); + txn.reset(db->NewTransaction()); LOG(ERROR) << "Processed " << count << " files."; } } // write the last batch if (count % 1000 != 0) { - CHECK(dataset->commit()); + txn->Commit(); LOG(ERROR) << "Processed " << count << " files."; } - dataset->close(); return 0; } diff --git a/tools/extract_features.cpp b/tools/extract_features.cpp index ddbce1075ed..c17b88dd048 100644 --- a/tools/extract_features.cpp +++ b/tools/extract_features.cpp @@ -7,19 +7,19 @@ #include "caffe/blob.hpp" #include "caffe/common.hpp" -#include "caffe/dataset_factory.hpp" #include "caffe/net.hpp" #include "caffe/proto/caffe.pb.h" +#include "caffe/util/db.hpp" #include "caffe/util/io.hpp" #include "caffe/vision_layers.hpp" -using boost::shared_ptr; using caffe::Blob; using caffe::Caffe; -using caffe::Dataset; -using caffe::DatasetFactory; using caffe::Datum; using caffe::Net; +using boost::shared_ptr; +using std::string; +namespace db = caffe::db; template int feature_extraction_pipeline(int argc, char** argv); @@ -121,13 +121,15 @@ int feature_extraction_pipeline(int argc, char** argv) { int num_mini_batches = atoi(argv[++arg_pos]); - std::vector > > feature_dbs; + std::vector > feature_dbs; + std::vector > txns; for (size_t i = 0; i < num_features; ++i) { LOG(INFO)<< "Opening dataset " << dataset_names[i]; - shared_ptr > dataset = - DatasetFactory(argv[++arg_pos]); - CHECK(dataset->open(dataset_names.at(i), Dataset::New)); - 
feature_dbs.push_back(dataset); + shared_ptr db(db::GetDB(argv[++arg_pos])); + db->Open(dataset_names.at(i), db::NEW); + feature_dbs.push_back(db); + shared_ptr txn(db->NewTransaction()); + txns.push_back(txn); } LOG(ERROR)<< "Extacting Features"; @@ -158,10 +160,13 @@ int feature_extraction_pipeline(int argc, char** argv) { } int length = snprintf(key_str, kMaxKeyStrLength, "%d", image_indices[i]); - CHECK(feature_dbs.at(i)->put(std::string(key_str, length), datum)); + string out; + CHECK(datum.SerializeToString(&out)); + txns.at(i)->Put(std::string(key_str, length), out); ++image_indices[i]; if (image_indices[i] % 1000 == 0) { - CHECK(feature_dbs.at(i)->commit()); + txns.at(i)->Commit(); + txns.at(i).reset(feature_dbs.at(i)->NewTransaction()); LOG(ERROR)<< "Extracted features of " << image_indices[i] << " query images for feature blob " << blob_names[i]; } @@ -171,11 +176,11 @@ int feature_extraction_pipeline(int argc, char** argv) { // write the last batch for (int i = 0; i < num_features; ++i) { if (image_indices[i] % 1000 != 0) { - CHECK(feature_dbs.at(i)->commit()); + txns.at(i)->Commit(); } LOG(ERROR)<< "Extracted features of " << image_indices[i] << " query images for feature blob " << blob_names[i]; - feature_dbs.at(i)->close(); + feature_dbs.at(i)->Close(); } LOG(ERROR)<< "Successfully extracted the features!"; From 88e3bc88b4ca09d51dd551b62e64e3c149cd4e0c Mon Sep 17 00:00:00 2001 From: Sergio Guadarrama Date: Fri, 16 Jan 2015 13:20:26 -0800 Subject: [PATCH 3/4] test db wrappers --- src/caffe/test/test_db.cpp | 134 +++++++++++++++++++++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 src/caffe/test/test_db.cpp diff --git a/src/caffe/test/test_db.cpp b/src/caffe/test/test_db.cpp new file mode 100644 index 00000000000..5b2ac230a0b --- /dev/null +++ b/src/caffe/test/test_db.cpp @@ -0,0 +1,134 @@ +#include + +#include "boost/scoped_ptr.hpp" +#include "gtest/gtest.h" + +#include "caffe/common.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/util/db.hpp" +#include "caffe/util/io.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +using boost::scoped_ptr; + +template +class DBTest : public ::testing::Test { + protected: + DBTest() + : backend_(TypeParam::backend), + root_images_(string(EXAMPLES_SOURCE_DIR) + string("images/")) {} + + virtual void SetUp() { + MakeTempDir(&source_); + source_ += "/db"; + string keys[] = {"cat.jpg", "fish-bike.jpg"}; + LOG(INFO) << "Using temporary db " << source_; + scoped_ptr db(db::GetDB(TypeParam::backend)); + db->Open(this->source_, db::NEW); + scoped_ptr txn(db->NewTransaction()); + for (int i = 0; i < 2; ++i) { + Datum datum; + ReadImageToDatum(root_images_ + keys[i], i, &datum); + string out; + CHECK(datum.SerializeToString(&out)); + txn->Put(keys[i], out); + } + txn->Commit(); + } + + virtual ~DBTest() { } + + DataParameter_DB backend_; + string source_; + string root_images_; +}; + +struct TypeLevelDB { + static DataParameter_DB backend; +}; +DataParameter_DB TypeLevelDB::backend = DataParameter_DB_LEVELDB; + +struct TypeLMDB { + static DataParameter_DB backend; +}; +DataParameter_DB TypeLMDB::backend = DataParameter_DB_LMDB; + +// typedef ::testing::Types TestTypes; +typedef ::testing::Types TestTypes; + +TYPED_TEST_CASE(DBTest, TestTypes); + +TYPED_TEST(DBTest, TestGetDB) { + scoped_ptr db(db::GetDB(TypeParam::backend)); +} + +TYPED_TEST(DBTest, TestNext) { + scoped_ptr db(db::GetDB(TypeParam::backend)); + db->Open(this->source_, db::READ); + scoped_ptr cursor(db->NewCursor()); + 
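+  // SetUp() committed exactly two entries ("cat.jpg", "fish-bike.jpg"), so
+  // the cursor starts on a valid record and two Next() calls exhaust it: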
EXPECT_TRUE(cursor->valid()); + cursor->Next(); + EXPECT_TRUE(cursor->valid()); + cursor->Next(); + EXPECT_FALSE(cursor->valid()); +} + +TYPED_TEST(DBTest, TestSeekToFirst) { + scoped_ptr db(db::GetDB(TypeParam::backend)); + db->Open(this->source_, db::READ); + scoped_ptr cursor(db->NewCursor()); + cursor->Next(); + cursor->SeekToFirst(); + EXPECT_TRUE(cursor->valid()); + string key = cursor->key(); + Datum datum; + datum.ParseFromString(cursor->value()); + EXPECT_EQ(key, "cat.jpg"); + EXPECT_EQ(datum.channels(), 3); + EXPECT_EQ(datum.height(), 360); + EXPECT_EQ(datum.width(), 480); +} + +TYPED_TEST(DBTest, TestKeyValue) { + scoped_ptr db(db::GetDB(TypeParam::backend)); + db->Open(this->source_, db::READ); + scoped_ptr cursor(db->NewCursor()); + EXPECT_TRUE(cursor->valid()); + string key = cursor->key(); + Datum datum; + datum.ParseFromString(cursor->value()); + EXPECT_EQ(key, "cat.jpg"); + EXPECT_EQ(datum.channels(), 3); + EXPECT_EQ(datum.height(), 360); + EXPECT_EQ(datum.width(), 480); + cursor->Next(); + EXPECT_TRUE(cursor->valid()); + key = cursor->key(); + datum.ParseFromString(cursor->value()); + EXPECT_EQ(key, "fish-bike.jpg"); + EXPECT_EQ(datum.channels(), 3); + EXPECT_EQ(datum.height(), 323); + EXPECT_EQ(datum.width(), 481); + cursor->Next(); + EXPECT_FALSE(cursor->valid()); +} + +TYPED_TEST(DBTest, TestWrite) { + scoped_ptr db(db::GetDB(TypeParam::backend)); + db->Open(this->source_, db::WRITE); + scoped_ptr txn(db->NewTransaction()); + Datum datum; + ReadFileToDatum(this->root_images_ + "cat.jpg", 0, &datum); + string out; + CHECK(datum.SerializeToString(&out)); + txn->Put("cat.jpg", out); + ReadFileToDatum(this->root_images_ + "fish-bike.jpg", 1, &datum); + CHECK(datum.SerializeToString(&out)); + txn->Put("fish-bike.jpg", out); + txn->Commit(); +} + +} // namespace caffe From 3b88e359f7f57a62608d42e2e6786a25eab4ecbb Mon Sep 17 00:00:00 2001 From: Jonathan L Long Date: Fri, 16 Jan 2015 13:25:00 -0800 Subject: [PATCH 4/4] gut dataset wrappers --- include/caffe/dataset.hpp | 241 --------- include/caffe/dataset_factory.hpp | 20 - include/caffe/leveldb_dataset.hpp | 90 ---- include/caffe/lmdb_dataset.hpp | 95 ---- src/caffe/dataset_factory.cpp | 50 -- src/caffe/leveldb_dataset.cpp | 265 ---------- src/caffe/lmdb_dataset.cpp | 366 -------------- src/caffe/test/test_dataset.cpp | 794 ------------------------------ 8 files changed, 1921 deletions(-) delete mode 100644 include/caffe/dataset.hpp delete mode 100644 include/caffe/dataset_factory.hpp delete mode 100644 include/caffe/leveldb_dataset.hpp delete mode 100644 include/caffe/lmdb_dataset.hpp delete mode 100644 src/caffe/dataset_factory.cpp delete mode 100644 src/caffe/leveldb_dataset.cpp delete mode 100644 src/caffe/lmdb_dataset.cpp delete mode 100644 src/caffe/test/test_dataset.cpp diff --git a/include/caffe/dataset.hpp b/include/caffe/dataset.hpp deleted file mode 100644 index 1dd8458cd74..00000000000 --- a/include/caffe/dataset.hpp +++ /dev/null @@ -1,241 +0,0 @@ -#ifndef CAFFE_DATASET_H_ -#define CAFFE_DATASET_H_ - -#include - -#include -#include -#include -#include -#include - -#include "caffe/common.hpp" -#include "caffe/proto/caffe.pb.h" - -namespace caffe { - -namespace dataset_internal { - -using google::protobuf::Message; - -template -struct static_assertion {}; -template<> -struct static_assertion { - enum { - DEFAULT_CODER_NOT_AVAILABLE - }; -}; - -template -struct DefaultCoder { - using static_assertion::DEFAULT_CODER_NOT_AVAILABLE; - static bool serialize(const T& obj, string* serialized); - static bool 
serialize(const T& obj, vector* serialized); - static bool deserialize(const string& serialized, T* obj); - static bool deserialize(const char* data, size_t size, T* obj); -}; - -template <> -struct DefaultCoder { - static bool serialize(const Message& obj, string* serialized) { - return obj.SerializeToString(serialized); - } - - static bool serialize(const Message& obj, vector* serialized) { - serialized->resize(obj.ByteSize()); - return obj.SerializeWithCachedSizesToArray( - reinterpret_cast(serialized->data())); - } - - static bool deserialize(const string& serialized, Message* obj) { - return obj->ParseFromString(serialized); - } - - static bool deserialize(const char* data, size_t size, Message* obj) { - return obj->ParseFromArray(data, size); - } -}; - -template <> -struct DefaultCoder : public DefaultCoder { }; - -template <> -struct DefaultCoder { - static bool serialize(string obj, string* serialized) { - *serialized = obj; - return true; - } - - static bool serialize(const string& obj, vector* serialized) { - vector temp(obj.data(), obj.data() + obj.size()); - serialized->swap(temp); - return true; - } - - static bool deserialize(const string& serialized, string* obj) { - *obj = serialized; - return true; - } - - static bool deserialize(const char* data, size_t size, string* obj) { - string temp_string(data, size); - obj->swap(temp_string); - return true; - } -}; - -template <> -struct DefaultCoder > { - static bool serialize(vector obj, string* serialized) { - string tmp(obj.data(), obj.size()); - serialized->swap(tmp); - return true; - } - - static bool serialize(const vector& obj, vector* serialized) { - *serialized = obj; - return true; - } - - static bool deserialize(const string& serialized, vector* obj) { - vector tmp(serialized.data(), serialized.data() + serialized.size()); - obj->swap(tmp); - return true; - } - - static bool deserialize(const char* data, size_t size, vector* obj) { - vector tmp(data, data + size); - obj->swap(tmp); - return true; - } -}; - -} // namespace dataset_internal - -template , - typename VCoder = dataset_internal::DefaultCoder > -class Dataset { - public: - enum Mode { - New, - ReadWrite, - ReadOnly - }; - - typedef K key_type; - typedef V value_type; - - struct KV { - K key; - V value; - }; - - virtual bool open(const string& filename, Mode mode) = 0; - virtual bool put(const K& key, const V& value) = 0; - virtual bool get(const K& key, V* value) = 0; - virtual bool first_key(K* key) = 0; - virtual bool last_key(K* key) = 0; - virtual bool commit() = 0; - virtual void close() = 0; - - virtual void keys(vector* keys) = 0; - - Dataset() { } - virtual ~Dataset() { } - - class iterator; - typedef iterator const_iterator; - - virtual const_iterator begin() const = 0; - virtual const_iterator cbegin() const = 0; - virtual const_iterator end() const = 0; - virtual const_iterator cend() const = 0; - - protected: - class DatasetState; - - public: - class iterator : public std::iterator { - public: - typedef KV T; - typedef T value_type; - typedef T& reference_type; - typedef T* pointer_type; - - iterator() - : parent_(NULL) { } - - iterator(const Dataset* parent, shared_ptr state) - : parent_(parent), - state_(state) { } - - iterator(const iterator& other) - : parent_(other.parent_), - state_(other.state_ ? 
other.state_->clone() - : shared_ptr()) { } - - iterator& operator=(iterator copy) { - copy.swap(*this); - return *this; - } - - void swap(iterator& other) throw() { - std::swap(this->parent_, other.parent_); - std::swap(this->state_, other.state_); - } - - bool operator==(const iterator& other) const { - return parent_->equal(state_, other.state_); - } - - bool operator!=(const iterator& other) const { - return !(*this == other); - } - - iterator& operator++() { - parent_->increment(&state_); - return *this; - } - iterator operator++(int) { - iterator copy(*this); - parent_->increment(&state_); - return copy; - } - - reference_type operator*() const { - return parent_->dereference(state_); - } - - pointer_type operator->() const { - return &parent_->dereference(state_); - } - - protected: - const Dataset* parent_; - shared_ptr state_; - }; - - protected: - class DatasetState { - public: - virtual ~DatasetState() { } - virtual shared_ptr clone() = 0; - }; - - virtual bool equal(shared_ptr state1, - shared_ptr state2) const = 0; - virtual void increment(shared_ptr* state) const = 0; - virtual KV& dereference( - shared_ptr state) const = 0; -}; - -} // namespace caffe - -#define INSTANTIATE_DATASET(type) \ - template class type; \ - template class type >; \ - template class type; - -#endif // CAFFE_DATASET_H_ diff --git a/include/caffe/dataset_factory.hpp b/include/caffe/dataset_factory.hpp deleted file mode 100644 index 57db49bf524..00000000000 --- a/include/caffe/dataset_factory.hpp +++ /dev/null @@ -1,20 +0,0 @@ -#ifndef CAFFE_DATASET_FACTORY_H_ -#define CAFFE_DATASET_FACTORY_H_ - -#include - -#include "caffe/common.hpp" -#include "caffe/dataset.hpp" -#include "caffe/proto/caffe.pb.h" - -namespace caffe { - -template -shared_ptr > DatasetFactory(const DataParameter_DB& type); - -template -shared_ptr > DatasetFactory(const string& type); - -} // namespace caffe - -#endif // CAFFE_DATASET_FACTORY_H_ diff --git a/include/caffe/leveldb_dataset.hpp b/include/caffe/leveldb_dataset.hpp deleted file mode 100644 index d58c181bb2b..00000000000 --- a/include/caffe/leveldb_dataset.hpp +++ /dev/null @@ -1,90 +0,0 @@ -#ifndef CAFFE_LEVELDB_DATASET_H_ -#define CAFFE_LEVELDB_DATASET_H_ - -#include -#include - -#include -#include -#include - -#include "caffe/common.hpp" -#include "caffe/dataset.hpp" - -namespace caffe { - -template , - typename VCoder = dataset_internal::DefaultCoder > -class LeveldbDataset : public Dataset { - public: - typedef Dataset Base; - typedef typename Base::key_type key_type; - typedef typename Base::value_type value_type; - typedef typename Base::DatasetState DatasetState; - typedef typename Base::Mode Mode; - typedef typename Base::const_iterator const_iterator; - typedef typename Base::KV KV; - - bool open(const string& filename, Mode mode); - bool put(const K& key, const V& value); - bool get(const K& key, V* value); - bool first_key(K* key); - bool last_key(K* key); - bool commit(); - void close(); - - void keys(vector* keys); - - const_iterator begin() const; - const_iterator cbegin() const; - const_iterator end() const; - const_iterator cend() const; - - protected: - class LeveldbState : public DatasetState { - public: - explicit LeveldbState(shared_ptr db, - shared_ptr iter) - : DatasetState(), - db_(db), - iter_(iter) { } - - ~LeveldbState() { - // This order is very important. - // Iterators must be destroyed before their associated DB - // is destroyed. 
- iter_.reset(); - db_.reset(); - } - - shared_ptr clone() { - shared_ptr new_iter; - - CHECK(iter_.get()); - new_iter.reset(db_->NewIterator(leveldb::ReadOptions())); - CHECK(iter_->Valid()); - new_iter->Seek(iter_->key()); - CHECK(new_iter->Valid()); - - return shared_ptr(new LeveldbState(db_, new_iter)); - } - - shared_ptr db_; - shared_ptr iter_; - KV kv_pair_; - }; - - bool equal(shared_ptr state1, - shared_ptr state2) const; - void increment(shared_ptr* state) const; - KV& dereference(shared_ptr state) const; - - shared_ptr db_; - shared_ptr batch_; - bool read_only_; -}; - -} // namespace caffe - -#endif // CAFFE_LEVELDB_DATASET_H_ diff --git a/include/caffe/lmdb_dataset.hpp b/include/caffe/lmdb_dataset.hpp deleted file mode 100644 index ac1e5ee25dd..00000000000 --- a/include/caffe/lmdb_dataset.hpp +++ /dev/null @@ -1,95 +0,0 @@ -#ifndef CAFFE_LMDB_DATASET_H_ -#define CAFFE_LMDB_DATASET_H_ - -#include -#include -#include - -#include "lmdb.h" - -#include "caffe/common.hpp" -#include "caffe/dataset.hpp" - -namespace caffe { - -template , - typename VCoder = dataset_internal::DefaultCoder > -class LmdbDataset : public Dataset { - public: - typedef Dataset Base; - typedef typename Base::key_type key_type; - typedef typename Base::value_type value_type; - typedef typename Base::DatasetState DatasetState; - typedef typename Base::Mode Mode; - typedef typename Base::const_iterator const_iterator; - typedef typename Base::KV KV; - - LmdbDataset() - : env_(NULL), - dbi_(0), - write_txn_(NULL), - read_txn_(NULL) { } - - bool open(const string& filename, Mode mode); - bool put(const K& key, const V& value); - bool get(const K& key, V* value); - bool first_key(K* key); - bool last_key(K* key); - bool commit(); - void close(); - - void keys(vector* keys); - - const_iterator begin() const; - const_iterator cbegin() const; - const_iterator end() const; - const_iterator cend() const; - - protected: - class LmdbState : public DatasetState { - public: - explicit LmdbState(MDB_cursor* cursor, MDB_txn* txn, const MDB_dbi* dbi) - : DatasetState(), - cursor_(cursor), - txn_(txn), - dbi_(dbi) { } - - shared_ptr clone() { - CHECK(cursor_); - - MDB_cursor* new_cursor; - int retval; - - retval = mdb_cursor_open(txn_, *dbi_, &new_cursor); - CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval); - MDB_val key; - MDB_val val; - retval = mdb_cursor_get(cursor_, &key, &val, MDB_GET_CURRENT); - CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval); - retval = mdb_cursor_get(new_cursor, &key, &val, MDB_SET); - CHECK_EQ(MDB_SUCCESS, retval) << mdb_strerror(retval); - - return shared_ptr(new LmdbState(new_cursor, txn_, dbi_)); - } - - MDB_cursor* cursor_; - MDB_txn* txn_; - const MDB_dbi* dbi_; - KV kv_pair_; - }; - - bool equal(shared_ptr state1, - shared_ptr state2) const; - void increment(shared_ptr* state) const; - KV& dereference(shared_ptr state) const; - - MDB_env* env_; - MDB_dbi dbi_; - MDB_txn* write_txn_; - MDB_txn* read_txn_; -}; - -} // namespace caffe - -#endif // CAFFE_LMDB_DATASET_H_ diff --git a/src/caffe/dataset_factory.cpp b/src/caffe/dataset_factory.cpp deleted file mode 100644 index 3313de3c408..00000000000 --- a/src/caffe/dataset_factory.cpp +++ /dev/null @@ -1,50 +0,0 @@ -#include -#include -#include - -#include "caffe/dataset_factory.hpp" -#include "caffe/leveldb_dataset.hpp" -#include "caffe/lmdb_dataset.hpp" - -namespace caffe { - -template -shared_ptr > DatasetFactory(const DataParameter_DB& type) { - switch (type) { - case DataParameter_DB_LEVELDB: - return shared_ptr >(new 
LeveldbDataset()); - case DataParameter_DB_LMDB: - return shared_ptr >(new LmdbDataset()); - default: - LOG(FATAL) << "Unknown dataset type " << type; - return shared_ptr >(); - } -} - -template -shared_ptr > DatasetFactory(const string& type) { - if ("leveldb" == type) { - return DatasetFactory(DataParameter_DB_LEVELDB); - } else if ("lmdb" == type) { - return DatasetFactory(DataParameter_DB_LMDB); - } else { - LOG(FATAL) << "Unknown dataset type " << type; - return shared_ptr >(); - } -} - -#define REGISTER_DATASET(key_type, value_type) \ - template shared_ptr > \ - DatasetFactory(const string& type); \ - template shared_ptr > \ - DatasetFactory(const DataParameter_DB& type); \ - -REGISTER_DATASET(string, string); -REGISTER_DATASET(string, vector); -REGISTER_DATASET(string, Datum); - -#undef REGISTER_DATASET - -} // namespace caffe - - diff --git a/src/caffe/leveldb_dataset.cpp b/src/caffe/leveldb_dataset.cpp deleted file mode 100644 index 53df985721c..00000000000 --- a/src/caffe/leveldb_dataset.cpp +++ /dev/null @@ -1,265 +0,0 @@ -#include -#include -#include - -#include "caffe/caffe.hpp" -#include "caffe/leveldb_dataset.hpp" - -namespace caffe { - -template -bool LeveldbDataset::open(const string& filename, - Mode mode) { - DLOG(INFO) << "LevelDB: Open " << filename; - - leveldb::Options options; - switch (mode) { - case Base::New: - DLOG(INFO) << " mode NEW"; - options.error_if_exists = true; - options.create_if_missing = true; - read_only_ = false; - break; - case Base::ReadWrite: - DLOG(INFO) << " mode RW"; - options.error_if_exists = false; - options.create_if_missing = true; - read_only_ = false; - break; - case Base::ReadOnly: - DLOG(INFO) << " mode RO"; - options.error_if_exists = false; - options.create_if_missing = false; - read_only_ = true; - break; - default: - DLOG(FATAL) << "unknown mode " << mode; - } - options.write_buffer_size = 268435456; - options.max_open_files = 100; - - leveldb::DB* db; - - LOG(INFO) << "Opening leveldb " << filename; - leveldb::Status status = leveldb::DB::Open( - options, filename, &db); - db_.reset(db); - - if (!status.ok()) { - LOG(ERROR) << "Failed to open leveldb " << filename - << ". 
Is it already existing?"; - return false; - } - - batch_.reset(new leveldb::WriteBatch()); - return true; -} - -template -bool LeveldbDataset::put(const K& key, const V& value) { - DLOG(INFO) << "LevelDB: Put"; - - if (read_only_) { - LOG(ERROR) << "put can not be used on a dataset in ReadOnly mode"; - return false; - } - - CHECK_NOTNULL(batch_.get()); - - string serialized_key; - if (!KCoder::serialize(key, &serialized_key)) { - return false; - } - - string serialized_value; - if (!VCoder::serialize(value, &serialized_value)) { - return false; - } - - batch_->Put(serialized_key, serialized_value); - - return true; -} - -template -bool LeveldbDataset::get(const K& key, V* value) { - DLOG(INFO) << "LevelDB: Get"; - - string serialized_key; - if (!KCoder::serialize(key, &serialized_key)) { - return false; - } - - string serialized_value; - leveldb::Status status = - db_->Get(leveldb::ReadOptions(), serialized_key, &serialized_value); - - if (!status.ok()) { - LOG(ERROR) << "leveldb get failed"; - return false; - } - - if (!VCoder::deserialize(serialized_value, value)) { - return false; - } - - return true; -} - -template -bool LeveldbDataset::first_key(K* key) { - DLOG(INFO) << "LevelDB: First key"; - - CHECK_NOTNULL(db_.get()); - shared_ptr iter(db_->NewIterator(leveldb::ReadOptions())); - iter->SeekToFirst(); - CHECK(iter->Valid()); - const leveldb::Slice& key_slice = iter->key(); - return KCoder::deserialize(key_slice.data(), key_slice.size(), key); -} - -template -bool LeveldbDataset::last_key(K* key) { - DLOG(INFO) << "LevelDB: Last key"; - - CHECK_NOTNULL(db_.get()); - shared_ptr iter(db_->NewIterator(leveldb::ReadOptions())); - iter->SeekToLast(); - CHECK(iter->Valid()); - const leveldb::Slice& key_slice = iter->key(); - return KCoder::deserialize(key_slice.data(), key_slice.size(), key); -} - -template -bool LeveldbDataset::commit() { - DLOG(INFO) << "LevelDB: Commit"; - - if (read_only_) { - LOG(ERROR) << "commit can not be used on a dataset in ReadOnly mode"; - return false; - } - - CHECK_NOTNULL(db_.get()); - CHECK_NOTNULL(batch_.get()); - - leveldb::Status status = db_->Write(leveldb::WriteOptions(), batch_.get()); - - batch_.reset(new leveldb::WriteBatch()); - - return status.ok(); -} - -template -void LeveldbDataset::close() { - DLOG(INFO) << "LevelDB: Close"; - - batch_.reset(); - db_.reset(); -} - -template -void LeveldbDataset::keys(vector* keys) { - DLOG(INFO) << "LevelDB: Keys"; - - keys->clear(); - for (const_iterator iter = begin(); iter != end(); ++iter) { - keys->push_back(iter->key); - } -} - -template -typename LeveldbDataset::const_iterator - LeveldbDataset::begin() const { - CHECK_NOTNULL(db_.get()); - shared_ptr iter(db_->NewIterator(leveldb::ReadOptions())); - iter->SeekToFirst(); - if (!iter->Valid()) { - iter.reset(); - } - - shared_ptr state; - if (iter) { - state.reset(new LeveldbState(db_, iter)); - } - return const_iterator(this, state); -} - -template -typename LeveldbDataset::const_iterator - LeveldbDataset::end() const { - shared_ptr state; - return const_iterator(this, state); -} - -template -typename LeveldbDataset::const_iterator - LeveldbDataset::cbegin() const { - return begin(); -} - -template -typename LeveldbDataset::const_iterator - LeveldbDataset::cend() const { return end(); } - -template -bool LeveldbDataset::equal( - shared_ptr state1, shared_ptr state2) const { - shared_ptr leveldb_state1 = - boost::dynamic_pointer_cast(state1); - - shared_ptr leveldb_state2 = - boost::dynamic_pointer_cast(state2); - - // The KV store doesn't really have 
any sort of ordering, - // so while we can do a sequential scan over the collection, - // we can't really use subranges. - return !leveldb_state1 && !leveldb_state2; -} - -template -void LeveldbDataset::increment( - shared_ptr* state) const { - shared_ptr leveldb_state = - boost::dynamic_pointer_cast(*state); - - CHECK_NOTNULL(leveldb_state.get()); - - shared_ptr& iter = leveldb_state->iter_; - - CHECK_NOTNULL(iter.get()); - CHECK(iter->Valid()); - - iter->Next(); - if (!iter->Valid()) { - state->reset(); - } -} - -template -typename Dataset::KV& - LeveldbDataset::dereference( - shared_ptr state) const { - shared_ptr leveldb_state = - boost::dynamic_pointer_cast(state); - - CHECK_NOTNULL(leveldb_state.get()); - - shared_ptr& iter = leveldb_state->iter_; - - CHECK_NOTNULL(iter.get()); - - CHECK(iter->Valid()); - - const leveldb::Slice& key = iter->key(); - const leveldb::Slice& value = iter->value(); - CHECK(KCoder::deserialize(key.data(), key.size(), - &leveldb_state->kv_pair_.key)); - CHECK(VCoder::deserialize(value.data(), value.size(), - &leveldb_state->kv_pair_.value)); - - return leveldb_state->kv_pair_; -} - -INSTANTIATE_DATASET(LeveldbDataset); - -} // namespace caffe diff --git a/src/caffe/lmdb_dataset.cpp b/src/caffe/lmdb_dataset.cpp deleted file mode 100644 index 8f8e68e901e..00000000000 --- a/src/caffe/lmdb_dataset.cpp +++ /dev/null @@ -1,366 +0,0 @@ -#include - -#include -#include -#include - -#include "caffe/caffe.hpp" -#include "caffe/lmdb_dataset.hpp" - -namespace caffe { - -template -bool LmdbDataset::open(const string& filename, - Mode mode) { - DLOG(INFO) << "LMDB: Open " << filename; - - CHECK(NULL == env_); - CHECK(NULL == write_txn_); - CHECK(NULL == read_txn_); - CHECK_EQ(0, dbi_); - - int retval; - if (mode != Base::ReadOnly) { - retval = mkdir(filename.c_str(), 0744); - switch (mode) { - case Base::New: - if (0 != retval) { - LOG(ERROR) << "mkdir " << filename << " failed"; - return false; - } - break; - case Base::ReadWrite: - if (-1 == retval && EEXIST != errno) { - LOG(ERROR) << "mkdir " << filename << " failed (" - << strerror(errno) << ")"; - return false; - } - break; - default: - LOG(FATAL) << "Invalid mode " << mode; - } - } - - retval = mdb_env_create(&env_); - if (MDB_SUCCESS != retval) { - LOG(ERROR) << "mdb_env_create failed " - << mdb_strerror(retval); - return false; - } - - retval = mdb_env_set_mapsize(env_, 1099511627776); - if (MDB_SUCCESS != retval) { - LOG(ERROR) << "mdb_env_set_mapsize failed " << mdb_strerror(retval); - return false; - } - - int flag1 = 0; - int flag2 = 0; - if (mode == Base::ReadOnly) { - flag1 = MDB_RDONLY | MDB_NOTLS; - flag2 = MDB_RDONLY; - } - - retval = mdb_env_open(env_, filename.c_str(), flag1, 0664); - if (MDB_SUCCESS != retval) { - LOG(ERROR) << "mdb_env_open failed " << mdb_strerror(retval); - return false; - } - - retval = mdb_txn_begin(env_, NULL, MDB_RDONLY, &read_txn_); - if (MDB_SUCCESS != retval) { - LOG(ERROR) << "mdb_txn_begin failed " << mdb_strerror(retval); - return false; - } - - retval = mdb_txn_begin(env_, NULL, flag2, &write_txn_); - if (MDB_SUCCESS != retval) { - LOG(ERROR) << "mdb_txn_begin failed " << mdb_strerror(retval); - return false; - } - - retval = mdb_open(write_txn_, NULL, 0, &dbi_); - if (MDB_SUCCESS != retval) { - LOG(ERROR) << "mdb_open failed" << mdb_strerror(retval); - return false; - } - - return true; -} - -template -bool LmdbDataset::put(const K& key, const V& value) { - DLOG(INFO) << "LMDB: Put"; - - vector serialized_key; - if (!KCoder::serialize(key, &serialized_key)) { - 
LOG(ERROR) << "failed to serialize key"; - return false; - } - - vector serialized_value; - if (!VCoder::serialize(value, &serialized_value)) { - LOG(ERROR) << "failed to serialized value"; - return false; - } - - MDB_val mdbkey, mdbdata; - mdbdata.mv_size = serialized_value.size(); - mdbdata.mv_data = serialized_value.data(); - mdbkey.mv_size = serialized_key.size(); - mdbkey.mv_data = serialized_key.data(); - - CHECK_NOTNULL(write_txn_); - CHECK_NE(0, dbi_); - - int retval = mdb_put(write_txn_, dbi_, &mdbkey, &mdbdata, 0); - if (MDB_SUCCESS != retval) { - LOG(ERROR) << "mdb_put failed " << mdb_strerror(retval); - return false; - } - - return true; -} - -template -bool LmdbDataset::get(const K& key, V* value) { - DLOG(INFO) << "LMDB: Get"; - - vector serialized_key; - if (!KCoder::serialize(key, &serialized_key)) { - LOG(ERROR) << "failed to serialized key"; - return false; - } - - MDB_val mdbkey, mdbdata; - mdbkey.mv_data = serialized_key.data(); - mdbkey.mv_size = serialized_key.size(); - - int retval; - retval = mdb_get(read_txn_, dbi_, &mdbkey, &mdbdata); - if (MDB_SUCCESS != retval) { - LOG(ERROR) << "mdb_get failed " << mdb_strerror(retval); - return false; - } - - if (!VCoder::deserialize(reinterpret_cast(mdbdata.mv_data), - mdbdata.mv_size, value)) { - LOG(ERROR) << "failed to deserialize value"; - return false; - } - - return true; -} - -template -bool LmdbDataset::first_key(K* key) { - DLOG(INFO) << "LMDB: First key"; - - int retval; - - MDB_cursor* cursor; - retval = mdb_cursor_open(read_txn_, dbi_, &cursor); - CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval); - MDB_val mdbkey; - MDB_val mdbval; - retval = mdb_cursor_get(cursor, &mdbkey, &mdbval, MDB_FIRST); - CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval); - - mdb_cursor_close(cursor); - - if (!KCoder::deserialize(reinterpret_cast(mdbkey.mv_data), - mdbkey.mv_size, key)) { - LOG(ERROR) << "failed to deserialize key"; - return false; - } - - return true; -} - -template -bool LmdbDataset::last_key(K* key) { - DLOG(INFO) << "LMDB: Last key"; - - int retval; - - MDB_cursor* cursor; - retval = mdb_cursor_open(read_txn_, dbi_, &cursor); - CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval); - MDB_val mdbkey; - MDB_val mdbval; - retval = mdb_cursor_get(cursor, &mdbkey, &mdbval, MDB_LAST); - CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval); - - mdb_cursor_close(cursor); - - if (!KCoder::deserialize(reinterpret_cast(mdbkey.mv_data), - mdbkey.mv_size, key)) { - LOG(ERROR) << "failed to deserialize key"; - return false; - } - - return true; -} - -template -bool LmdbDataset::commit() { - DLOG(INFO) << "LMDB: Commit"; - - CHECK_NOTNULL(write_txn_); - - int retval; - retval = mdb_txn_commit(write_txn_); - if (MDB_SUCCESS != retval) { - LOG(ERROR) << "mdb_txn_commit failed " << mdb_strerror(retval); - return false; - } - - mdb_txn_abort(read_txn_); - - retval = mdb_txn_begin(env_, NULL, 0, &write_txn_); - if (MDB_SUCCESS != retval) { - LOG(ERROR) << "mdb_txn_begin failed " << mdb_strerror(retval); - return false; - } - - retval = mdb_txn_begin(env_, NULL, MDB_RDONLY, &read_txn_); - if (MDB_SUCCESS != retval) { - LOG(ERROR) << "mdb_txn_begin failed " << mdb_strerror(retval); - return false; - } - - return true; -} - -template -void LmdbDataset::close() { - DLOG(INFO) << "LMDB: Close"; - - if (env_ && dbi_) { - mdb_txn_abort(write_txn_); - mdb_txn_abort(read_txn_); - mdb_close(env_, dbi_); - mdb_env_close(env_); - env_ = NULL; - dbi_ = 0; - write_txn_ = NULL; - read_txn_ = NULL; - } -} - -template -void 
LmdbDataset::keys(vector* keys) { - DLOG(INFO) << "LMDB: Keys"; - - keys->clear(); - for (const_iterator iter = begin(); iter != end(); ++iter) { - keys->push_back(iter->key); - } -} - -template -typename LmdbDataset::const_iterator - LmdbDataset::begin() const { - int retval; - - MDB_cursor* cursor; - retval = mdb_cursor_open(read_txn_, dbi_, &cursor); - CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval); - MDB_val key; - MDB_val val; - retval = mdb_cursor_get(cursor, &key, &val, MDB_FIRST); - - CHECK(MDB_SUCCESS == retval || MDB_NOTFOUND == retval) - << mdb_strerror(retval); - - shared_ptr state; - if (MDB_SUCCESS == retval) { - state.reset(new LmdbState(cursor, read_txn_, &dbi_)); - } else { - mdb_cursor_close(cursor); - } - return const_iterator(this, state); -} - -template -typename LmdbDataset::const_iterator - LmdbDataset::end() const { - shared_ptr state; - return const_iterator(this, state); -} - -template -typename LmdbDataset::const_iterator - LmdbDataset::cbegin() const { return begin(); } - -template -typename LmdbDataset::const_iterator - LmdbDataset::cend() const { return end(); } - -template -bool LmdbDataset::equal(shared_ptr state1, - shared_ptr state2) const { - shared_ptr lmdb_state1 = - boost::dynamic_pointer_cast(state1); - - shared_ptr lmdb_state2 = - boost::dynamic_pointer_cast(state2); - - // The KV store doesn't really have any sort of ordering, - // so while we can do a sequential scan over the collection, - // we can't really use subranges. - return !lmdb_state1 && !lmdb_state2; -} - -template -void LmdbDataset::increment( - shared_ptr* state) const { - shared_ptr lmdb_state = - boost::dynamic_pointer_cast(*state); - - CHECK_NOTNULL(lmdb_state.get()); - - MDB_cursor*& cursor = lmdb_state->cursor_; - - CHECK_NOTNULL(cursor); - - MDB_val key; - MDB_val val; - int retval = mdb_cursor_get(cursor, &key, &val, MDB_NEXT); - if (MDB_NOTFOUND == retval) { - mdb_cursor_close(cursor); - state->reset(); - } else { - CHECK_EQ(MDB_SUCCESS, retval) << mdb_strerror(retval); - } -} - -template -typename Dataset::KV& - LmdbDataset::dereference( - shared_ptr state) const { - shared_ptr lmdb_state = - boost::dynamic_pointer_cast(state); - - CHECK_NOTNULL(lmdb_state.get()); - - MDB_cursor*& cursor = lmdb_state->cursor_; - - CHECK_NOTNULL(cursor); - - MDB_val mdb_key; - MDB_val mdb_val; - int retval = mdb_cursor_get(cursor, &mdb_key, &mdb_val, MDB_GET_CURRENT); - CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval); - - CHECK(KCoder::deserialize(reinterpret_cast(mdb_key.mv_data), - mdb_key.mv_size, &lmdb_state->kv_pair_.key)); - CHECK(VCoder::deserialize(reinterpret_cast(mdb_val.mv_data), - mdb_val.mv_size, &lmdb_state->kv_pair_.value)); - - return lmdb_state->kv_pair_; -} - -INSTANTIATE_DATASET(LmdbDataset); - -} // namespace caffe diff --git a/src/caffe/test/test_dataset.cpp b/src/caffe/test/test_dataset.cpp deleted file mode 100644 index 6645ca228d2..00000000000 --- a/src/caffe/test/test_dataset.cpp +++ /dev/null @@ -1,794 +0,0 @@ -#include -#include - -#include "caffe/util/io.hpp" - -#include "gtest/gtest.h" - -#include "caffe/dataset_factory.hpp" - -#include "caffe/test/test_caffe_main.hpp" - -namespace caffe { - -namespace DatasetTest_internal { - -template -struct TestData { - static T TestValue(); - static T TestAltValue(); - static bool equals(const T& a, const T& b); -}; - -template <> -string TestData::TestValue() { - return "world"; -} - -template <> -string TestData::TestAltValue() { - return "bar"; -} - -template <> -bool TestData::equals(const string& a, const 
-  return a == b;
-}
-
-template <>
-vector<char> TestData<vector<char> >::TestValue() {
-  string str = "world";
-  vector<char> val(str.data(), str.data() + str.size());
-  return val;
-}
-
-template <>
-vector<char> TestData<vector<char> >::TestAltValue() {
-  string str = "bar";
-  vector<char> val(str.data(), str.data() + str.size());
-  return val;
-}
-
-template <>
-bool TestData<vector<char> >::equals(const vector<char>& a,
-    const vector<char>& b) {
-  if (a.size() != b.size()) {
-    return false;
-  }
-  for (size_t i = 0; i < a.size(); ++i) {
-    if (a.at(i) != b.at(i)) {
-      return false;
-    }
-  }
-
-  return true;
-}
-
-template <>
-Datum TestData<Datum>::TestValue() {
-  Datum datum;
-  datum.set_channels(3);
-  datum.set_height(32);
-  datum.set_width(32);
-  datum.set_data(string(32 * 32 * 3 * 4, ' '));
-  datum.set_label(0);
-  return datum;
-}
-
-template <>
-Datum TestData<Datum>::TestAltValue() {
-  Datum datum;
-  datum.set_channels(1);
-  datum.set_height(64);
-  datum.set_width(64);
-  datum.set_data(string(64 * 64 * 1 * 4, ' '));
-  datum.set_label(1);
-  return datum;
-}
-
-template <>
-bool TestData<Datum>::equals(const Datum& a, const Datum& b) {
-  string serialized_a;
-  a.SerializeToString(&serialized_a);
-
-  string serialized_b;
-  b.SerializeToString(&serialized_b);
-
-  return serialized_a == serialized_b;
-}
-
-}  // namespace DatasetTest_internal
-
-#define UNPACK_TYPES \
-  typedef typename TypeParam::value_type value_type; \
-  const DataParameter_DB backend = TypeParam::backend;
-
-template <typename TypeParam>
-class DatasetTest : public ::testing::Test {
- protected:
-  typedef typename TypeParam::value_type value_type;
-
-  string DBName() {
-    string filename;
-    MakeTempDir(&filename);
-    filename += "/db";
-    return filename;
-  }
-
-  string TestKey() {
-    return "hello";
-  }
-
-  value_type TestValue() {
-    return DatasetTest_internal::TestData<value_type>::TestValue();
-  }
-
-  string TestAltKey() {
-    return "foo";
-  }
-
-  value_type TestAltValue() {
-    return DatasetTest_internal::TestData<value_type>::TestAltValue();
-  }
-
-  template <typename T>
-  bool equals(const T& a, const T& b) {
-    return DatasetTest_internal::TestData<T>::equals(a, b);
-  }
-};
-
-struct StringLeveldb {
-  typedef string value_type;
-  static const DataParameter_DB backend;
-};
-const DataParameter_DB StringLeveldb::backend = DataParameter_DB_LEVELDB;
-
-struct StringLmdb {
-  typedef string value_type;
-  static const DataParameter_DB backend;
-};
-const DataParameter_DB StringLmdb::backend = DataParameter_DB_LMDB;
-
-struct VectorLeveldb {
-  typedef vector<char> value_type;
-  static const DataParameter_DB backend;
-};
-const DataParameter_DB VectorLeveldb::backend = DataParameter_DB_LEVELDB;
-
-struct VectorLmdb {
-  typedef vector<char> value_type;
-  static const DataParameter_DB backend;
-};
-const DataParameter_DB VectorLmdb::backend = DataParameter_DB_LMDB;
-
-struct DatumLeveldb {
-  typedef Datum value_type;
-  static const DataParameter_DB backend;
-};
-const DataParameter_DB DatumLeveldb::backend = DataParameter_DB_LEVELDB;
-
-struct DatumLmdb {
-  typedef Datum value_type;
-  static const DataParameter_DB backend;
-};
-const DataParameter_DB DatumLmdb::backend = DataParameter_DB_LMDB;
-
-typedef ::testing::Types<StringLeveldb, StringLmdb, VectorLeveldb,
-    VectorLmdb, DatumLeveldb, DatumLmdb> TestTypes;
-
-TYPED_TEST_CASE(DatasetTest, TestTypes);
-
-TYPED_TEST(DatasetTest, TestNewDoesntExistPasses) {
-  UNPACK_TYPES;
-
-  shared_ptr<Dataset<string, value_type> > dataset =
-      DatasetFactory<string, value_type>(backend);
-  EXPECT_TRUE(dataset->open(this->DBName(),
-      Dataset<string, value_type>::New));
-  dataset->close();
-}
-
-TYPED_TEST(DatasetTest, TestNewExistsFails) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Dataset<string, value_type> > dataset =
-      DatasetFactory<string, value_type>(backend);
-  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
-  dataset->close();
-
-  EXPECT_FALSE(dataset->open(name, Dataset<string, value_type>::New));
-}
-
-TYPED_TEST(DatasetTest, TestReadOnlyExistsPasses) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Dataset<string, value_type> > dataset =
-      DatasetFactory<string, value_type>(backend);
-  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
-  dataset->close();
-
-  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::ReadOnly));
-  dataset->close();
-}
-
-TYPED_TEST(DatasetTest, TestReadOnlyDoesntExistFails) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Dataset<string, value_type> > dataset =
-      DatasetFactory<string, value_type>(backend);
-  EXPECT_FALSE(dataset->open(name, Dataset<string, value_type>::ReadOnly));
-}
-
-TYPED_TEST(DatasetTest, TestReadWriteExistsPasses) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Dataset<string, value_type> > dataset =
-      DatasetFactory<string, value_type>(backend);
-  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
-  dataset->close();
-
-  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::ReadWrite));
-  dataset->close();
-}
-
-TYPED_TEST(DatasetTest, TestReadWriteDoesntExistPasses) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Dataset<string, value_type> > dataset =
-      DatasetFactory<string, value_type>(backend);
-  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::ReadWrite));
-  dataset->close();
-}
-
-TYPED_TEST(DatasetTest, TestKeys) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Dataset<string, value_type> > dataset =
-      DatasetFactory<string, value_type>(backend);
-  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
-
-  string key1 = this->TestKey();
-  value_type value1 = this->TestValue();
-
-  EXPECT_TRUE(dataset->put(key1, value1));
-
-  string key2 = this->TestAltKey();
-  value_type value2 = this->TestAltValue();
-
-  EXPECT_TRUE(dataset->put(key2, value2));
-
-  EXPECT_TRUE(dataset->commit());
-
-  vector<string> keys;
-  dataset->keys(&keys);
-
-  EXPECT_EQ(2, keys.size());
-
-  EXPECT_TRUE(this->equals(keys.at(0), key1) ||
-      this->equals(keys.at(0), key2));
-  EXPECT_TRUE(this->equals(keys.at(1), key1) ||
-      this->equals(keys.at(1), key2));
-  EXPECT_FALSE(this->equals(keys.at(0), keys.at(1)));
-}
-
-TYPED_TEST(DatasetTest, TestFirstKey) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Dataset<string, value_type> > dataset =
-      DatasetFactory<string, value_type>(backend);
-  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
-
-  value_type value = this->TestValue();
-
-  string key1 = "01";
-  EXPECT_TRUE(dataset->put(key1, value));
-
-  string key2 = "02";
-  EXPECT_TRUE(dataset->put(key2, value));
-
-  string key3 = "03";
-  EXPECT_TRUE(dataset->put(key3, value));
-
-  EXPECT_TRUE(dataset->commit());
-
-  string first_key;
-  dataset->first_key(&first_key);
-
-  EXPECT_TRUE(this->equals(first_key, key1));
-}
-
-TYPED_TEST(DatasetTest, TestLastKey) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Dataset<string, value_type> > dataset =
-      DatasetFactory<string, value_type>(backend);
-  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
-
-  value_type value = this->TestValue();
-
-  string key1 = "01";
-  EXPECT_TRUE(dataset->put(key1, value));
-
-  string key2 = "02";
-  EXPECT_TRUE(dataset->put(key2, value));
-
-  string key3 = "03";
-  EXPECT_TRUE(dataset->put(key3, value));
-
-  EXPECT_TRUE(dataset->commit());
-
-  string last_key;
-  dataset->last_key(&last_key);
-
-  EXPECT_TRUE(this->equals(last_key, key3));
-}
-
-TYPED_TEST(DatasetTest, TestFirstLastKeys) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Dataset<string, value_type> > dataset =
-      DatasetFactory<string, value_type>(backend);
-  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
-
-  value_type value = this->TestValue();
-
-  string key1 = "01";
-  EXPECT_TRUE(dataset->put(key1, value));
-
-  string key2 = "02";
-  EXPECT_TRUE(dataset->put(key2, value));
-
-  string key3 = "03";
-  EXPECT_TRUE(dataset->put(key3, value));
-
-  EXPECT_TRUE(dataset->commit());
-
-  string first_key;
-  dataset->first_key(&first_key);
-  string last_key;
-  dataset->last_key(&last_key);
-
-  EXPECT_TRUE(this->equals(first_key, key1));
-  EXPECT_TRUE(this->equals(last_key, key3));
-}
-
-TYPED_TEST(DatasetTest, TestFirstLastKeysUnOrdered) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Dataset<string, value_type> > dataset =
-      DatasetFactory<string, value_type>(backend);
-  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
-
-  value_type value = this->TestValue();
-
-  string key3 = "03";
-  EXPECT_TRUE(dataset->put(key3, value));
-
-  string key1 = "01";
-  EXPECT_TRUE(dataset->put(key1, value));
-
-  string key2 = "02";
-  EXPECT_TRUE(dataset->put(key2, value));
-
-  EXPECT_TRUE(dataset->commit());
-
-  string first_key;
-  dataset->first_key(&first_key);
-  string last_key;
-  dataset->last_key(&last_key);
-
-  EXPECT_TRUE(this->equals(first_key, key1));
-  EXPECT_TRUE(this->equals(last_key, key3));
-}
-
-TYPED_TEST(DatasetTest, TestKeysNoCommit) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Dataset<string, value_type> > dataset =
-      DatasetFactory<string, value_type>(backend);
-  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
-
-  string key1 = this->TestKey();
-  value_type value1 = this->TestValue();
-
-  EXPECT_TRUE(dataset->put(key1, value1));
-
-  string key2 = this->TestAltKey();
-  value_type value2 = this->TestAltValue();
-
-  EXPECT_TRUE(dataset->put(key2, value2));
-
-  vector<string> keys;
-  dataset->keys(&keys);
-
-  EXPECT_EQ(0, keys.size());
-}
-
-TYPED_TEST(DatasetTest, TestIterators) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Dataset<string, value_type> > dataset =
-      DatasetFactory<string, value_type>(backend);
-  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
-
-  const int kNumExamples = 4;
-  for (int i = 0; i < kNumExamples; ++i) {
-    stringstream ss;
-    ss << i;
-    string key = ss.str();
-    ss << " here be data";
-    value_type value = this->TestValue();
-    EXPECT_TRUE(dataset->put(key, value));
-  }
-  EXPECT_TRUE(dataset->commit());
-
-  int count = 0;
-  typedef typename Dataset<string, value_type>::const_iterator Iter;
-  for (Iter iter = dataset->begin(); iter != dataset->end(); ++iter) {
-    (void)iter;
-    ++count;
-  }
-
-  EXPECT_EQ(kNumExamples, count);
-}
-
-TYPED_TEST(DatasetTest, TestIteratorsPreIncrement) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Dataset<string, value_type> > dataset =
-      DatasetFactory<string, value_type>(backend);
-  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
-
-  string key1 = this->TestAltKey();
-  value_type value1 = this->TestAltValue();
-
-  string key2 = this->TestKey();
-  value_type value2 = this->TestValue();
-
-  EXPECT_TRUE(dataset->put(key1, value1));
-  EXPECT_TRUE(dataset->put(key2, value2));
-  EXPECT_TRUE(dataset->commit());
-
-  typename Dataset<string, value_type>::const_iterator iter1 =
-      dataset->begin();
-
-  EXPECT_FALSE(dataset->end() == iter1);
-
-  EXPECT_TRUE(this->equals(iter1->key, key1));
-
-  typename Dataset<string, value_type>::const_iterator iter2 = ++iter1;
-
-  EXPECT_FALSE(dataset->end() == iter1);
-  EXPECT_FALSE(dataset->end() == iter2);
-
-  EXPECT_TRUE(this->equals(iter2->key, key2));
-
-  typename Dataset<string, value_type>::const_iterator iter3 = ++iter2;
-
-  EXPECT_TRUE(dataset->end() == iter3);
-
-  dataset->close();
-}
-
-TYPED_TEST(DatasetTest, TestIteratorsPostIncrement) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Dataset<string, value_type> > dataset =
-      DatasetFactory<string, value_type>(backend);
-  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
-
-  string key1 = this->TestAltKey();
-  value_type value1 = this->TestAltValue();
-
-  string key2 = this->TestKey();
-  value_type value2 = this->TestValue();
-
-  EXPECT_TRUE(dataset->put(key1, value1));
-  EXPECT_TRUE(dataset->put(key2, value2));
-  EXPECT_TRUE(dataset->commit());
-
-  typename Dataset<string, value_type>::const_iterator iter1 =
-      dataset->begin();
-
-  EXPECT_FALSE(dataset->end() == iter1);
-
-  EXPECT_TRUE(this->equals(iter1->key, key1));
-
-  typename Dataset<string, value_type>::const_iterator iter2 = iter1++;
-
-  EXPECT_FALSE(dataset->end() == iter1);
-  EXPECT_FALSE(dataset->end() == iter2);
-
-  EXPECT_TRUE(this->equals(iter2->key, key1));
-  EXPECT_TRUE(this->equals(iter1->key, key2));
-
-  typename Dataset<string, value_type>::const_iterator iter3 = iter1++;
-
-  EXPECT_FALSE(dataset->end() == iter3);
-  EXPECT_TRUE(this->equals(iter3->key, key2));
-  EXPECT_TRUE(dataset->end() == iter1);
-
-  dataset->close();
-}
-
-TYPED_TEST(DatasetTest, TestNewPutPasses) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Dataset<string, value_type> > dataset =
-      DatasetFactory<string, value_type>(backend);
-  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
-
-  string key = this->TestKey();
-  value_type value = this->TestValue();
-
-  EXPECT_TRUE(dataset->put(key, value));
-
-  EXPECT_TRUE(dataset->commit());
-
-  dataset->close();
-}
-
-TYPED_TEST(DatasetTest, TestNewCommitPasses) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Dataset<string, value_type> > dataset =
-      DatasetFactory<string, value_type>(backend);
-  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
-
-  EXPECT_TRUE(dataset->commit());
-
-  dataset->close();
-}
-
-TYPED_TEST(DatasetTest, TestNewGetPasses) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Dataset<string, value_type> > dataset =
-      DatasetFactory<string, value_type>(backend);
-  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
-
-  string key = this->TestKey();
-  value_type value = this->TestValue();
-
-  EXPECT_TRUE(dataset->put(key, value));
-
-  EXPECT_TRUE(dataset->commit());
-
-  value_type new_value;
-
-  EXPECT_TRUE(dataset->get(key, &new_value));
-
-  EXPECT_TRUE(this->equals(value, new_value));
-
-  dataset->close();
-}
-
-TYPED_TEST(DatasetTest, TestNewGetNoCommitFails) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Dataset<string, value_type> > dataset =
-      DatasetFactory<string, value_type>(backend);
-  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
-
-  string key = this->TestKey();
-  value_type value = this->TestValue();
-
-  EXPECT_TRUE(dataset->put(key, value));
-
-  value_type new_value;
-
-  EXPECT_FALSE(dataset->get(key, &new_value));
-}
-
-
-TYPED_TEST(DatasetTest, TestReadWritePutPasses) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Dataset<string, value_type> > dataset =
-      DatasetFactory<string, value_type>(backend);
-  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::ReadWrite));
-
-  string key = this->TestKey();
-  value_type value = this->TestValue();
-
-  EXPECT_TRUE(dataset->put(key, value));
-
-  EXPECT_TRUE(dataset->commit());
-
-  dataset->close();
-}
-
-TYPED_TEST(DatasetTest, TestReadWriteCommitPasses) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Dataset<string, value_type> > dataset =
-      DatasetFactory<string, value_type>(backend);
-  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::ReadWrite));
-
-  EXPECT_TRUE(dataset->commit());
-
-  dataset->close();
-}
-
-TYPED_TEST(DatasetTest, TestReadWriteGetPasses) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Dataset<string, value_type> > dataset =
-      DatasetFactory<string, value_type>(backend);
-  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
-
-  string key = this->TestKey();
-  value_type value = this->TestValue();
-
-  EXPECT_TRUE(dataset->put(key, value));
-
-  EXPECT_TRUE(dataset->commit());
-
-  value_type new_value;
-
-  EXPECT_TRUE(dataset->get(key, &new_value));
-
-  EXPECT_TRUE(this->equals(value, new_value));
-
-  dataset->close();
-}
-
-TYPED_TEST(DatasetTest, TestReadWriteGetNoCommitFails) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Dataset<string, value_type> > dataset =
-      DatasetFactory<string, value_type>(backend);
-  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
-
-  string key = this->TestKey();
-  value_type value = this->TestValue();
-
-  EXPECT_TRUE(dataset->put(key, value));
-
-  value_type new_value;
-
-  EXPECT_FALSE(dataset->get(key, &new_value));
-}
-
-TYPED_TEST(DatasetTest, TestReadOnlyPutFails) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Dataset<string, value_type> > dataset =
-      DatasetFactory<string, value_type>(backend);
-  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
-  dataset->close();
-
-  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::ReadOnly));
-
-  string key = this->TestKey();
-  value_type value = this->TestValue();
-
-  EXPECT_FALSE(dataset->put(key, value));
-}
-
-TYPED_TEST(DatasetTest, TestReadOnlyCommitFails) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Dataset<string, value_type> > dataset =
-      DatasetFactory<string, value_type>(backend);
-  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
-  dataset->close();
-
-  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::ReadOnly));
-
-  EXPECT_FALSE(dataset->commit());
-}
-
-TYPED_TEST(DatasetTest, TestReadOnlyGetPasses) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Dataset<string, value_type> > dataset =
-      DatasetFactory<string, value_type>(backend);
-  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
-
-  string key = this->TestKey();
-  value_type value = this->TestValue();
-
-  EXPECT_TRUE(dataset->put(key, value));
-
-  EXPECT_TRUE(dataset->commit());
-
-  dataset->close();
-
-  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::ReadOnly));
-
-  value_type new_value;
-
-  EXPECT_TRUE(dataset->get(key, &new_value));
-
-  EXPECT_TRUE(this->equals(value, new_value));
-}
-
-TYPED_TEST(DatasetTest, TestReadOnlyGetNoCommitFails) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Dataset<string, value_type> > dataset =
-      DatasetFactory<string, value_type>(backend);
-  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
-
-  string key = this->TestKey();
-  value_type value = this->TestValue();
-
-  EXPECT_TRUE(dataset->put(key, value));
-
-  dataset->close();
-
-  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::ReadOnly));
-
-  value_type new_value;
-
-  EXPECT_FALSE(dataset->get(key, &new_value));
-}
-
-TYPED_TEST(DatasetTest, TestCreateManyItersShortScope) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Dataset<string, value_type> > dataset =
-      DatasetFactory<string, value_type>(backend);
-  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
-
-  string key = this->TestKey();
-  value_type value = this->TestValue();
-  EXPECT_TRUE(dataset->put(key, value));
-  EXPECT_TRUE(dataset->commit());
-
-  for (int i = 0; i < 1000; ++i) {
-    typename Dataset<string, value_type>::const_iterator iter =
-        dataset->begin();
-  }
-}
-
-TYPED_TEST(DatasetTest, TestCreateManyItersLongScope) {
-  UNPACK_TYPES;
-
-  string name = this->DBName();
-  shared_ptr<Dataset<string, value_type> > dataset =
-      DatasetFactory<string, value_type>(backend);
-  EXPECT_TRUE(dataset->open(name, Dataset<string, value_type>::New));
-
-  string key = this->TestKey();
-  value_type value = this->TestValue();
-  EXPECT_TRUE(dataset->put(key, value));
-  EXPECT_TRUE(dataset->commit());
-
-  vector<typename Dataset<string, value_type>::const_iterator> iters;
-  for (int i = 0; i < 1000; ++i) {
-    iters.push_back(dataset->begin());
-  }
-}
-
-#undef UNPACK_TYPES
-
-}  // namespace caffe
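Note for reviewers (illustrative sketch, not part of the patch): the Dataset/DatasetFactory surface deleted above is superseded by the db::DB, db::Transaction, and db::Cursor wrappers introduced at the start of this patch. A minimal put-then-scan round trip against the new interface looks roughly as follows; the database path is a placeholder and error handling is elided:

    #include "caffe/util/db.hpp"

    int main() {
      caffe::db::LevelDB database;
      database.Open("/tmp/example_db", caffe::db::NEW);

      // Writes are buffered in a Transaction and become visible on Commit().
      caffe::db::Transaction* txn = database.NewTransaction();
      txn->Put("hello", "world");
      txn->Commit();
      delete txn;

      // A Cursor provides the sequential scan that Dataset iterators offered.
      caffe::db::Cursor* cursor = database.NewCursor();
      for (cursor->SeekToFirst(); cursor->valid(); cursor->Next()) {
        LOG(INFO) << cursor->key() << " -> " << cursor->value();
      }
      delete cursor;

      database.Close();
      return 0;
    }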