Skip to content

Commit

Permalink
[engine] bug fix: to pickle/unpickle within the proper context
Browse files Browse the repository at this point in the history
We have two places that pickle and unpickle functions are not used within
the proper context.

* StringIO buffer for write, saving buffer to storage happens outside its context.
* lmdb buffer for read, the invalid buffer was used outside the transaction context.

The earlier attempt was probably a red herring: https://rbcommons.com/s/twitter/r/3751/.
Thanks Stu for catching!

These two behaviors are clearly documented, see [1] and [2].
[1] https://docs.python.org/2/library/stringio.html
[2] https://lmdb.readthedocs.org/en/release/

Testing Done:
https://travis-ci.org/peiyuwang/pants/builds/125890340 passed.
https://travis-ci.org/peiyuwang/pants/builds/125941151 passed.

Bugs closed: 3149, 3274

Reviewed at https://rbcommons.com/s/twitter/r/3761/
  • Loading branch information
peiyuwang committed Apr 27, 2016
1 parent 94388b3 commit f22f8cc
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 23 deletions.
62 changes: 41 additions & 21 deletions src/python/pants/engine/exp/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,22 @@
from pants.util.meta import AbstractClass


def _unpickle(value):
  """Deserialize a pickled value.

  :param value: Either a bytestring (deserialized with `pickle.loads`) or a
    file-like object such as a `StringIO` buffer (deserialized with `pickle.load`).
  :return: The unpickled object.
  """
  if not isinstance(value, six.binary_type):
    # File-like values expose a read interface; use the stream-based loader.
    return pickle.load(value)
  # String-like values are loaded directly from the bytestring.
  return pickle.loads(value)


def _identity(value):
return value


def _copy_bytes(value):
return bytes(value)


@total_ordering
class Key(object):
"""Holds the digest for the object, which uniquely identifies it.
Expand Down Expand Up @@ -171,16 +187,16 @@ def put(self, obj):
pickler.fast = 1
pickler.dump(obj)
blob = buf.getvalue()
# Hash the blob and store it if it does not exist.
key = Key.create(blob, type(obj), str(obj) if self._debug else None)

self._contents.put(key.digest, blob)
except Exception as e:
# Unfortunately, pickle can raise things other than PickleError instances. For example it
# will raise ValueError when handed a lambda; so we handle the otherwise overly-broad
# `Exception` type here.
raise SerializationError('Failed to pickle {}: {}'.format(obj, e))

# Hash the blob and store it if it does not exist.
key = Key.create(blob, type(obj), str(obj) if self._debug else None)

self._contents.put(key.digest, blob)
return key

def puts(self, objs):
Expand All @@ -202,12 +218,8 @@ def get(self, key):
if not isinstance(key, Key):
raise InvalidKeyError('Not a valid key: {}'.format(key))

value = self._contents.get(key.digest)
if isinstance(value, six.binary_type):
# loads for string-like values
return self._assert_type_matches(pickle.loads(value), key.type)
# load for file-like value from buffers
return self._assert_type_matches(pickle.load(value), key.type)
value = self._contents.get(key.digest, _unpickle)
return self._assert_type_matches(value, key.type)

def add_mapping(self, from_key, to_key):
"""Establish one to one relationship from one Key to another Key.
Expand Down Expand Up @@ -390,21 +402,26 @@ def __repr__(self):

class KeyValueStore(Closable, AbstractClass):
@abstractmethod
def get(self, key):
def get(self, key, transform=_identity):
"""Fetch the value for a given key.
:param key: key in bytestring.
:param transform: optional function that is applied on the retrieved value from storage
before it is returned, since the original value may be only valid within the context.
:return: value can be either string-like or file-like, `None` if does not exist.
"""

@abstractmethod
def put(self, key, value):
def put(self, key, value, transform=_copy_bytes):
"""Save the value under a key, but only once.
The write once semantics is specifically provided for the content addressable use case.
:param key: key in bytestring.
:param value: value in bytestring.
:param transform: optional function that is applied on the input value before it is
saved to the storage, since the original value may be only valid within the context,
default is to play safe and make a copy.
:return: `True` to indicate the write actually happens, i.e, first write, `False` for
repeated writes of the same key.
"""
Expand All @@ -423,13 +440,13 @@ class InMemoryDb(KeyValueStore):
def __init__(self):
self._storage = dict()

def get(self, key):
return self._storage.get(key)
def get(self, key, transform=_identity):
return transform(self._storage.get(key))

def put(self, key, value):
def put(self, key, value, transform=_copy_bytes):
if key in self._storage:
return False
self._storage[key] = value
self._storage[key] = transform(value)
return True

def items(self):
Expand Down Expand Up @@ -483,7 +500,7 @@ def __init__(self, env, db=None):
def path(self):
return self._env.path()

def get(self, key):
def get(self, key, transform=_identity):
"""Return the value or `None` if the key does not exist.
NB: Memory mapped storage returns a buffer object without copying keys or values, which
Expand All @@ -493,13 +510,16 @@ def get(self, key):
with self._env.begin(db=self._db, buffers=True) as txn:
value = txn.get(key)
if value is not None:
return StringIO.StringIO(value)
return transform(StringIO.StringIO(value))
return None

def put(self, key, value):
"""Returning True if the key/value are actually written to the storage."""
def put(self, key, value, transform=_identity):
"""Returning True if the key/value are actually written to the storage.
No need to do additional transform since value is to be persisted.
"""
with self._env.begin(db=self._db, buffers=True, write=True) as txn:
return txn.put(key, value, overwrite=False)
return txn.put(key, transform(value), overwrite=False)

def items(self):
with self._env.begin(db=self._db, buffers=True) as txn:
Expand Down
2 changes: 0 additions & 2 deletions tests/python/pants_test/engine/exp/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def test_serial_engine_simple(self):
engine = LocalSerialEngine(self.scheduler, self.storage, self.cache)
self.assert_engine(engine)

@unittest.skip('https://github.com/pantsbuild/pants/issues/3149')
def test_multiprocess_engine_multi(self):
with self.multiprocessing_engine() as engine:
self.assert_engine(engine)
Expand All @@ -62,7 +61,6 @@ def test_multiprocess_unpickleable(self):
with self.assertRaises(SerializationError):
engine.execute(build_request)

@unittest.skip('https://github.com/pantsbuild/pants/issues/3149')
def test_rerun_with_cache(self):
with self.multiprocessing_engine() as engine:
self.assert_engine(engine)
Expand Down

0 comments on commit f22f8cc

Please sign in to comment.