diff --git a/digits/dataset/images/generic/test_lmdb_creator.py b/digits/dataset/images/generic/test_lmdb_creator.py
index 6b44bcffa..eed80fc91 100755
--- a/digits/dataset/images/generic/test_lmdb_creator.py
+++ b/digits/dataset/images/generic/test_lmdb_creator.py
@@ -65,19 +65,12 @@ def create_lmdbs(folder, image_width=None, image_height=None, image_count=None):
             ('train', train_image_count),
             ('val', val_image_count)]:
         image_db = lmdb.open(os.path.join(folder, '%s_images' % phase),
-                map_size=1024**4, # 1TB
                 map_async=True,
                 max_dbs=0)
         label_db = lmdb.open(os.path.join(folder, '%s_labels' % phase),
-                map_size=1024**4, # 1TB
                 map_async=True,
                 max_dbs=0)
 
-        write_batch_size = 10
-
-        image_txn = image_db.begin(write=True)
-        label_txn = label_db.begin(write=True)
-
         image_sum = np.zeros((image_height, image_width), 'float64')
 
         for i in xrange(image_count):
@@ -101,19 +94,13 @@ def create_lmdbs(folder, image_width=None, image_height=None, image_count=None):
             pil_img.save(s, format='PNG')
             image_datum.data = s.getvalue()
             image_datum.encoded = True
-            image_txn.put(str(i), image_datum.SerializeToString())
+            _write_to_lmdb(image_db, str(i), image_datum.SerializeToString())
 
             # create label Datum
             label_datum = caffe_pb2.Datum()
             label_datum.channels, label_datum.height, label_datum.width = 1, 1, 2
             label_datum.float_data.extend(np.array([xslope, yslope]).flat)
-            label_txn.put(str(i), label_datum.SerializeToString())
-
-            if ((i+1)%write_batch_size) == 0:
-                image_txn.commit()
-                image_txn = image_db.begin(write=True)
-                label_txn.commit()
-                label_txn = label_db.begin(write=True)
+            _write_to_lmdb(label_db, str(i), label_datum.SerializeToString())
 
         # close databases
         image_db.close()
@@ -137,6 +124,26 @@ def create_lmdbs(folder, image_width=None, image_height=None, image_count=None):
 
     return test_image_filename
 
+def _write_to_lmdb(db, key, value):
+    """
+    Write (key,value) to db
+    """
+    success = False
+    while not success:
+        txn = db.begin(write=True)
+        try:
+            txn.put(key, value)
+            txn.commit()
+            success = True
+        except lmdb.MapFullError:
+            txn.abort()
+
+            # double the map_size
+            curr_limit = db.info()['map_size']
+            new_limit = curr_limit*2
+            print '>>> Doubling LMDB map size to %sMB ...' % (new_limit>>20,)
+            db.set_mapsize(new_limit) # double it
+
 def _save_mean(mean, filename):
     """
     Saves mean to file
diff --git a/requirements.txt b/requirements.txt
index 8572d29ad..29f1397d9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,6 +8,6 @@ gevent>=1.0
 Flask>=0.10.1
 Flask-WTF>=0.11
 gunicorn==17.5
+lmdb>=0.87
 pydot2
 Flask-SocketIO
-lmdb
diff --git a/tools/create_db.py b/tools/create_db.py
index aabab439a..55af4dbc8 100755
--- a/tools/create_db.py
+++ b/tools/create_db.py
@@ -35,27 +35,34 @@ logger = logging.getLogger('digits.tools.create_db')
 
-class DbCreator:
+class DbCreator(object):
     """
    Creates a database for a neural network imageset
    """

-    def __init__(self, db_path):
+    def __init__(self, db_path, lmdb_map_size=None):
        """
        Arguments:
        db_path -- where should the database be created
+
+        Keyword arguments:
+        lmdb_map_size -- the initial LMDB map size
        """
        # Can have trailing slash or not
        self.output_path = os.path.dirname(os.path.join(db_path, ''))
        self.name = os.path.basename(self.output_path)
 
+        if lmdb_map_size:
+            # convert from MB to B
+            lmdb_map_size <<= 20
+
        if os.path.exists(self.output_path):
            # caffe throws an error instead
            logger.warning('removing existing database %s' % self.output_path)
            rmtree(self.output_path, ignore_errors=True)
 
        self.db = lmdb.open(self.output_path,
-                map_size=1000000000000, # ~1TB
+                map_size=lmdb_map_size,
                map_async=True,
                max_dbs=0)
 
@@ -105,7 +112,7 @@ def create(self, input_file, width, height,
             logger.error('unsupported number of channels')
             return False
         self.channels = channels
-        if resize_mode not in ['crop', 'squash', 'fill', 'half_crop']:
+        if resize_mode not in [None, 'crop', 'squash', 'fill', 'half_crop']:
             logger.error('unsupported resize_mode')
             return False
         self.resize_mode = resize_mode
@@ -138,10 +145,9 @@ def create(self, input_file, width, height,
             #XXX This is the only way to preserve order for now
             # This obviously hurts performance considerably
             read_threads = 1
-            write_threads = 1
         else:
             read_threads = 10
-            write_threads = 10
+        write_threads = 1
         batch_size = 100
 
         total_images_added = 0
@@ -448,7 +454,21 @@ def write_batch(self, batch):
         keys = self.get_keys(len(batch))
         lmdb_txn = self.db.begin(write=True)
         for i, datum in enumerate(batch):
-            lmdb_txn.put('%08d_%d' % (keys[i], datum.label), datum.SerializeToString())
+            try:
+                key = '%08d_%d' % (keys[i], datum.label)
+                lmdb_txn.put(key, datum.SerializeToString())
+            except lmdb.MapFullError:
+                lmdb_txn.abort()
+
+                # double the map_size
+                curr_limit = self.db.info()['map_size']
+                new_limit = curr_limit*2
+                logger.debug('Doubling LMDB map size to %sMB ...' % (new_limit>>20,))
+                self.db.set_mapsize(new_limit) # double it
+
+                # try again
+                self.write_batch(batch)
+                return
         lmdb_txn.commit()
 
     def get_keys(self, num):
@@ -493,7 +513,6 @@ def get_keys(self, num):
             help='channels of resized images (1 for grayscale, 3 for color [default])'
             )
     parser.add_argument('-r', '--resize_mode',
-            default='squash',
             help='resize mode for images (must be "crop", "squash" [default], "fill" or "half_crop")'
             )
     parser.add_argument('-m', '--mean_file', action='append',
@@ -508,10 +527,14 @@ def get_keys(self, num):
             default = 'none',
             help = 'Choose encoding format ("jpg", "png" or "none" [default])'
             )
+    parser.add_argument('--lmdb_map_size',
+            type=int,
+            help = 'The initial map size for LMDB (in MB)')
 
     args = vars(parser.parse_args())
 
-    db = DbCreator(args['db_name'])
+    db = DbCreator(args['db_name'],
+            lmdb_map_size = args['lmdb_map_size'])
 
     if db.create(args['input_file'], args['width'], args['height'],
             channels = args['channels'],
diff --git a/tools/test_create_db.py b/tools/test_create_db.py
index c467fe25b..4e7a8f063 100644
--- a/tools/test_create_db.py
+++ b/tools/test_create_db.py
@@ -4,16 +4,17 @@
 import tempfile
 import shutil
 from cStringIO import StringIO
+import unittest
+import platform
 
 from nose.tools import raises, assert_raises
 import mock
-import unittest
 import PIL.Image
 import numpy as np
 
 from . import create_db as _
 
-class TestInit():
+class TestInit(object):
     @classmethod
     def setUpClass(cls):
         cls.db_name = tempfile.mkdtemp()
@@ -25,7 +26,7 @@ def tearDownClass(cls):
         except OSError:
             pass
 
-class TestCreate():
+class TestCreate(object):
     @classmethod
     def setUpClass(cls):
         cls.db_name = tempfile.mkdtemp()
@@ -88,7 +89,7 @@ def test_create_normal(self):
             resize_mode='crop'), 'database should complete building normally'
 
 
-class TestPathToDatum():
+class TestPathToDatum(object):
     @classmethod
     def setUpClass(cls):
         cls.tmpdir = tempfile.mkdtemp()
@@ -128,5 +129,42 @@ def check_configs(self, args):
         else:
             assert d.encoded, 'datum should be encoded when encoding="%s"' % e
 
-class TestSaveMean():
-    pass
+
+class TestMapSize(object):
+    """
+    Tests regarding the LMDB map_size argument
+    """
+
+    def test_default_mapsize(self):
+        db_name = tempfile.mkdtemp()
+        db = _.DbCreator(db_name)
+        assert db.db.info()['map_size'] == (10<<20), 'Default map_size %s != 10MB' % db.db.info()['map_size']
+
+    @unittest.skipIf(platform.system() != 'Linux',
+            'This test fails on non-Linux systems')
+    def test_huge_mapsize(self):
+        db_name = tempfile.mkdtemp()
+        mapsize_mb = 1024**2 # 1TB should be no problem
+        db = _.DbCreator(db_name, lmdb_map_size=mapsize_mb)
+
+    def test_set_mapsize(self):
+        # create textfile
+        fd, input_file = tempfile.mkstemp()
+        os.close(fd)
+        with open(input_file, 'w') as f:
+            f.write('digits/static/images/mona_lisa.jpg 0')
+
+        # create DbCreator object
+        db_name = tempfile.mkdtemp()
+        mapsize_mb = 1
+        db = _.DbCreator(db_name, lmdb_map_size=mapsize_mb)
+
+        # create db
+        image_size = 1000 # big enough to trigger a MapFullError
+        assert db.create(
+                input_file,
+                width=image_size,
+                height=image_size,
+                ), 'create failed'
+
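
--
Editor's note (not part of the patch): both _write_to_lmdb() above and DbCreator.write_batch() apply the same pattern -- start with a small LMDB map, catch lmdb.MapFullError, abort the failed transaction, double the map via Environment.set_mapsize(), and retry the write. set_mapsize() is presumably why requirements.txt now pins lmdb>=0.87, the py-lmdb release that introduced it. Below is a minimal standalone sketch of that retry-and-grow loop; the path, key/value sizes, and function name are made up for illustration:

    import lmdb

    def put_with_growth(env, key, value):
        """Retry a single write, doubling the map size on MapFullError."""
        while True:
            txn = env.begin(write=True)
            try:
                txn.put(key, value)
                txn.commit()
                return
            except lmdb.MapFullError:
                # MapFullError can come from put() or from commit(); either
                # way, abort (a no-op if already finalized) and grow the map.
                txn.abort()
                env.set_mapsize(env.info()['map_size'] * 2)

    env = lmdb.open('/tmp/lmdb_growth_demo', map_size=1 << 20)  # start at 1MB
    for i in range(1000):
        # ~4MB of values in a 1MB map, forcing several doublings
        put_with_growth(env, str(i).encode(), b'x' * 4096)
    env.close()

On Linux the map is a sparse memory mapping, so map_size mostly reserves address space and pages are only allocated as data is written; that is what makes grow-on-demand cheap, and it appears to be why test_huge_mapsize in the new TestMapSize class is skipped on non-Linux platforms, where a huge map is not necessarily this cheap.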