Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove all references to LevelDB #203

Merged
merged 1 commit into from
Aug 10, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions digits/dataset/tasks/create_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def __init__(self, input_file, db_name, image_dims, **kwargs):
resize_mode -- used in utils.image.resize_image()
encoding -- 'none', 'png' or 'jpg'
mean_file -- save mean file to this location
backend -- type of database to use
labels_file -- used to print category distribution
"""
# Take keyword arguments out of kwargs
Expand All @@ -39,7 +38,6 @@ def __init__(self, input_file, db_name, image_dims, **kwargs):
self.resize_mode = kwargs.pop('resize_mode' , None)
self.encoding = kwargs.pop('encoding', None)
self.mean_file = kwargs.pop('mean_file', None)
self.backend = kwargs.pop('backend', None)
self.labels_file = kwargs.pop('labels_file', None)

super(CreateDbTask, self).__init__(**kwargs)
Expand Down
45 changes: 10 additions & 35 deletions tools/create_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

import numpy as np
import PIL.Image
import leveldb
import lmdb
from cStringIO import StringIO
# must call digits.config.load_config() before caffe to set the path
Expand All @@ -41,13 +40,10 @@ class DbCreator:
Creates a database for a neural network imageset
"""

def __init__(self, db_path, backend='lmdb'):
def __init__(self, db_path):
"""
Arguments:
db_path -- where should the database be created

Keyword arguments:
backend -- 'lmdb' or 'leveldb'
"""
# Can have trailing slash or not
self.output_path = os.path.dirname(os.path.join(db_path, ''))
Expand All @@ -58,17 +54,10 @@ def __init__(self, db_path, backend='lmdb'):
logger.warning('removing existing database %s' % self.output_path)
rmtree(self.output_path, ignore_errors=True)

if backend == 'lmdb':
self.backend = 'lmdb'
self.db = lmdb.open(self.output_path,
map_size=1000000000000, # ~1TB
map_async=True,
max_dbs=0)
elif backend == 'leveldb':
self.backend = 'leveldb'
self.db = leveldb.LevelDB(self.output_path, error_if_exists=True)
else:
raise ValueError('unknown backend: "%s"' % backend)
self.db = lmdb.open(self.output_path,
map_size=1000000000000, # ~1TB
map_async=True,
max_dbs=0)

self.shutdown = threading.Event()
self.keys_lock = threading.Lock()
Expand Down Expand Up @@ -457,19 +446,10 @@ def write_batch(self, batch):
batch -- an array of Datums
"""
keys = self.get_keys(len(batch))
if self.backend == 'lmdb':
lmdb_txn = self.db.begin(write=True)
for i, datum in enumerate(batch):
lmdb_txn.put('%08d_%d' % (keys[i], datum.label), datum.SerializeToString())
lmdb_txn.commit()
elif self.backend == 'leveldb':
leveldb_batch = leveldb.WriteBatch()
for i, datum in enumerate(batch):
leveldb_batch.Put('%08d_%d' % (keys[i], datum.label), datum.SerializeToString())
self.db.Write(leveldb_batch)
else:
logger.error('unsupported backend')
return False
lmdb_txn = self.db.begin(write=True)
for i, datum in enumerate(batch):
lmdb_txn.put('%08d_%d' % (keys[i], datum.label), datum.SerializeToString())
lmdb_txn.commit()

def get_keys(self, num):
"""
Expand Down Expand Up @@ -524,19 +504,14 @@ def get_keys(self, num):
action='store_true',
help='Shuffle images before saving'
)
parser.add_argument('-b', '--backend',
default='lmdb',
help='db backend [default=lmdb]'
)
parser.add_argument('-e', '--encoding',
default = 'none',
help = 'Choose encoding format ("jpg", "png" or "none" [default])'
)

args = vars(parser.parse_args())

db = DbCreator(args['db_name'],
backend=args['backend'])
db = DbCreator(args['db_name'])

if db.create(args['input_file'], args['width'], args['height'],
channels = args['channels'],
Expand Down
8 changes: 2 additions & 6 deletions tools/test_create_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,11 @@ def tearDownClass(cls):
except OSError:
pass

@raises(ValueError)
def test_bad_backend(self):
_.DbCreator(self.db_name, 'not-a-backend')

class TestCreate():
@classmethod
def setUpClass(cls):
cls.db_name = tempfile.mkdtemp()
cls.db = _.DbCreator(cls.db_name, 'leveldb')
cls.db = _.DbCreator(cls.db_name)

fd, cls.input_file = tempfile.mkstemp()
os.close(fd)
Expand Down Expand Up @@ -97,7 +93,7 @@ class TestPathToDatum():
def setUpClass(cls):
cls.tmpdir = tempfile.mkdtemp()
cls.db_name = tempfile.mkdtemp(dir=cls.tmpdir)
cls.db = _.DbCreator(cls.db_name, 'lmdb')
cls.db = _.DbCreator(cls.db_name)
_handle, cls.image_path = tempfile.mkstemp(dir=cls.tmpdir, suffix='.jpg')
with open(cls.image_path, 'w') as outfile:
PIL.Image.fromarray(np.zeros((10,10,3),dtype=np.uint8)).save(outfile, format='JPEG', quality=100)
Expand Down