Skip to content

Commit

Permalink
Merge pull request #209 from lukeyeager/lmdb-map-size
Browse files Browse the repository at this point in the history
Double LMDB map_size on MapFullError
  • Loading branch information
lukeyeager committed Aug 12, 2015
2 parents 6e84807 + 944bea6 commit 7eb15d4
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 31 deletions.
37 changes: 22 additions & 15 deletions digits/dataset/images/generic/test_lmdb_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,19 +65,12 @@ def create_lmdbs(folder, image_width=None, image_height=None, image_count=None):
('train', train_image_count),
('val', val_image_count)]:
image_db = lmdb.open(os.path.join(folder, '%s_images' % phase),
map_size=1024**4, # 1TB
map_async=True,
max_dbs=0)
label_db = lmdb.open(os.path.join(folder, '%s_labels' % phase),
map_size=1024**4, # 1TB
map_async=True,
max_dbs=0)

write_batch_size = 10

image_txn = image_db.begin(write=True)
label_txn = label_db.begin(write=True)

image_sum = np.zeros((image_height, image_width), 'float64')

for i in xrange(image_count):
Expand All @@ -101,19 +94,13 @@ def create_lmdbs(folder, image_width=None, image_height=None, image_count=None):
pil_img.save(s, format='PNG')
image_datum.data = s.getvalue()
image_datum.encoded = True
image_txn.put(str(i), image_datum.SerializeToString())
_write_to_lmdb(image_db, str(i), image_datum.SerializeToString())

# create label Datum
label_datum = caffe_pb2.Datum()
label_datum.channels, label_datum.height, label_datum.width = 1, 1, 2
label_datum.float_data.extend(np.array([xslope, yslope]).flat)
label_txn.put(str(i), label_datum.SerializeToString())

if ((i+1)%write_batch_size) == 0:
image_txn.commit()
image_txn = image_db.begin(write=True)
label_txn.commit()
label_txn = label_db.begin(write=True)
_write_to_lmdb(label_db, str(i), label_datum.SerializeToString())

# close databases
image_db.close()
Expand All @@ -137,6 +124,26 @@ def create_lmdbs(folder, image_width=None, image_height=None, image_count=None):

return test_image_filename

def _write_to_lmdb(db, key, value):
"""
Write (key,value) to db
"""
success = False
while not success:
txn = db.begin(write=True)
try:
txn.put(key, value)
txn.commit()
success = True
except lmdb.MapFullError:
txn.abort()

# double the map_size
curr_limit = db.info()['map_size']
new_limit = curr_limit*2
print '>>> Doubling LMDB map size to %sMB ...' % (new_limit>>20,)
db.set_mapsize(new_limit) # double it

def _save_mean(mean, filename):
"""
Saves mean to file
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ gevent>=1.0
Flask>=0.10.1
Flask-WTF>=0.11
gunicorn==17.5
lmdb>=0.87
pydot2
Flask-SocketIO
lmdb
41 changes: 32 additions & 9 deletions tools/create_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,27 +35,34 @@

logger = logging.getLogger('digits.tools.create_db')

class DbCreator:
class DbCreator(object):
"""
Creates a database for a neural network imageset
"""

def __init__(self, db_path, lmdb_map_size=None):
    """
    Arguments:
    db_path -- where should the database be created

    Keyword arguments:
    lmdb_map_size -- the initial LMDB map size, in MB
        (None lets LMDB pick its default; the map is doubled on
        MapFullError when batches are written)
    """
    # Can have trailing slash or not
    self.output_path = os.path.dirname(os.path.join(db_path, ''))
    self.name = os.path.basename(self.output_path)

    if lmdb_map_size:
        # convert from MB to B
        lmdb_map_size <<= 20

    if os.path.exists(self.output_path):
        # caffe throws an error instead
        logger.warning('removing existing database %s' % self.output_path)
        rmtree(self.output_path, ignore_errors=True)

    self.db = lmdb.open(self.output_path,
            map_size=lmdb_map_size,
            map_async=True,
            max_dbs=0)

Expand Down Expand Up @@ -105,7 +112,7 @@ def create(self, input_file, width, height,
logger.error('unsupported number of channels')
return False
self.channels = channels
if resize_mode not in ['crop', 'squash', 'fill', 'half_crop']:
if resize_mode not in [None, 'crop', 'squash', 'fill', 'half_crop']:
logger.error('unsupported resize_mode')
return False
self.resize_mode = resize_mode
Expand Down Expand Up @@ -138,10 +145,9 @@ def create(self, input_file, width, height,
#XXX This is the only way to preserve order for now
# This obviously hurts performance considerably
read_threads = 1
write_threads = 1
else:
read_threads = 10
write_threads = 10
write_threads = 1
batch_size = 100

total_images_added = 0
def write_batch(self, batch):
    """
    Write a batch of datums to the database in a single transaction.

    If the current LMDB map is too small for the batch, the partial
    transaction is aborted, the map size is doubled, and the whole
    batch is retried recursively (with freshly allocated keys).

    Arguments:
    batch -- list of Datum objects to store
    """
    keys = self.get_keys(len(batch))
    lmdb_txn = self.db.begin(write=True)
    for i, datum in enumerate(batch):
        try:
            key = '%08d_%d' % (keys[i], datum.label)
            lmdb_txn.put(key, datum.SerializeToString())
        except lmdb.MapFullError:
            # discard the partial writes before resizing
            lmdb_txn.abort()

            # double the map_size
            curr_limit = self.db.info()['map_size']
            new_limit = curr_limit*2
            logger.debug('Doubling LMDB map size to %sMB ...' % (new_limit>>20,))
            self.db.set_mapsize(new_limit)  # double it

            # try again
            self.write_batch(batch)
            return
    lmdb_txn.commit()

def get_keys(self, num):
Expand Down Expand Up @@ -493,7 +513,6 @@ def get_keys(self, num):
help='channels of resized images (1 for grayscale, 3 for color [default])'
)
parser.add_argument('-r', '--resize_mode',
default='squash',
help='resize mode for images (must be "crop", "squash" [default], "fill" or "half_crop")'
)
parser.add_argument('-m', '--mean_file', action='append',
Expand All @@ -508,10 +527,14 @@ def get_keys(self, num):
default = 'none',
help = 'Choose encoding format ("jpg", "png" or "none" [default])'
)
parser.add_argument('--lmdb_map_size',
type=int,
help = 'The initial map size for LMDB (in MB)')

args = vars(parser.parse_args())

db = DbCreator(args['db_name'])
db = DbCreator(args['db_name'],
lmdb_map_size = args['lmdb_map_size'])

if db.create(args['input_file'], args['width'], args['height'],
channels = args['channels'],
Expand Down
50 changes: 44 additions & 6 deletions tools/test_create_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@
import tempfile
import shutil
from cStringIO import StringIO
import unittest
import platform

from nose.tools import raises, assert_raises
import mock
import unittest
import PIL.Image
import numpy as np

from . import create_db as _

class TestInit():
class TestInit(object):
@classmethod
def setUpClass(cls):
cls.db_name = tempfile.mkdtemp()
Expand All @@ -25,7 +26,7 @@ def tearDownClass(cls):
except OSError:
pass

class TestCreate():
class TestCreate(object):
@classmethod
def setUpClass(cls):
cls.db_name = tempfile.mkdtemp()
Expand Down Expand Up @@ -88,7 +89,7 @@ def test_create_normal(self):
resize_mode='crop'), 'database should complete building normally'


class TestPathToDatum():
class TestPathToDatum(object):
@classmethod
def setUpClass(cls):
cls.tmpdir = tempfile.mkdtemp()
Expand Down Expand Up @@ -128,5 +129,42 @@ def check_configs(self, args):
else:
assert d.encoded, 'datum should be encoded when encoding="%s"' % e

class TestSaveMean(object):
    # Placeholder: _save_mean is not covered yet.
    # Inherit from object for consistency with the other new-style
    # test classes in this module.
    pass

class TestMapSize(object):
    """
    Tests regarding the LMDB map_size argument
    """

    def test_default_mapsize(self):
        # no explicit size -> LMDB's own default applies
        folder = tempfile.mkdtemp()
        creator = _.DbCreator(folder)
        actual = creator.db.info()['map_size']
        assert actual == (10<<20), 'Default map_size %s != 10MB' % actual

    @unittest.skipIf(platform.system() != 'Linux',
            'This test fails on non-Linux systems')
    def test_huge_mapsize(self):
        # 1TB should be no problem
        folder = tempfile.mkdtemp()
        _.DbCreator(folder, lmdb_map_size=1024**2)

    def test_set_mapsize(self):
        # write a one-line input file pointing at a sample image
        handle, list_file = tempfile.mkstemp()
        os.close(handle)
        with open(list_file, 'w') as outfile:
            outfile.write('digits/static/images/mona_lisa.jpg 0')

        # start with a deliberately tiny (1MB) map ...
        creator = _.DbCreator(tempfile.mkdtemp(), lmdb_map_size=1)

        # ... and an image big enough to trigger a MapFullError
        side = 1000
        assert creator.create(
            list_file,
            width=side,
            height=side,
        ), 'create failed'


0 comments on commit 7eb15d4

Please sign in to comment.