👌 IMPROVE: Allow numpy arrays to be serialized on process checkpoints #4730

Merged: 6 commits, Jun 16, 2021

Changes from all commits:
2 changes: 1 addition & 1 deletion aiida/engine/persistence.py
@@ -121,7 +121,7 @@ def load_checkpoint(self, pid: Hashable, tag: Optional[str] = None) -> plumpy.pe
             raise PersistenceError(f'Calculation<{calculation.pk}> does not have a saved checkpoint')

         try:
-            bundle = serialize.deserialize(checkpoint)
+            bundle = serialize.deserialize_unsafe(checkpoint)
         except Exception:
             raise PersistenceError(f'Failed to load the checkpoint for process<{pid}>: {traceback.format_exc()}')
2 changes: 1 addition & 1 deletion aiida/engine/processes/process.py
@@ -604,7 +604,7 @@ def decode_input_args(self, encoded: str) -> Dict[str, Any]:  # pylint: disable=
         :param encoded: encoded (serialized) inputs
         :return: The decoded input args
         """
-        return serialize.deserialize(encoded)
+        return serialize.deserialize_unsafe(encoded)

     def update_node_state(self, state: plumpy.process_states.State) -> None:
         self.update_outputs()
2 changes: 1 addition & 1 deletion aiida/manage/manager.py
@@ -264,7 +264,7 @@ def create_communicator(
         if with_orm:
             from aiida.orm.utils import serialize
             encoder = functools.partial(serialize.serialize, encoding='utf-8')
-            decoder = serialize.deserialize
+            decoder = serialize.deserialize_unsafe
         else:
            # used by verdi status to get a communicator without needing to load the dbenv
            from aiida.common import json
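For context, the communicator uses this encoder/decoder pair to turn task payloads into strings and back. A minimal round-trip sketch of the pair as assigned in the diff above (the message content is made up; it assumes the pair round-trips plain Python builtins, which the tests in this PR exercise):

```python
import functools

from aiida.orm.utils import serialize

# The same pair the diff assigns to the communicator.
encoder = functools.partial(serialize.serialize, encoding='utf-8')
decoder = serialize.deserialize_unsafe

# Whatever the communicator transports must survive this round trip.
message = {'task': 'continue', 'pid': 42}
assert decoder(encoder(message)) == message
```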
13 changes: 6 additions & 7 deletions aiida/orm/utils/serialize.py
@@ -176,13 +176,11 @@ def represent_data(self, data):
         return super().represent_data(data)


-class AiiDALoader(yaml.FullLoader):
+class AiiDALoader(yaml.UnsafeLoader):
     """AiiDA specific yaml loader

-    .. note:: we subclass the `FullLoader` which is the one that since `pyyaml>=5.1` is the loader that prevents
-        arbitrary code execution. Even though this is in principle only used internally, one could imagine someone
-        sharing a database with a maliciously crafted process instance dump, which when reloaded could execute arbitrary
-        code. This load prevents this: https://github.com/yaml/pyyaml/wiki/PyYAML-yaml.load(input)-Deprecation
+    .. note:: The `AiiDALoader` should only be used on trusted input, because it uses the `yaml.UnsafeLoader`. When
+        importing a shared database, we strip all process node checkpoints to avoid this being a security risk.
     """

@@ -219,10 +217,11 @@ def serialize(data, encoding=None):
     return serialized


-def deserialize(serialized):
+def deserialize_unsafe(serialized):
     """Deserialize a yaml dump that represents a serialized data structure.

-    .. note:: no need to use `yaml.safe_load` here because the `Loader` will ensure that loading is safe.
+    .. note:: This function should not be used on untrusted input, because
+        it is built upon `yaml.UnsafeLoader`.

     :param serialized: a yaml serialized string representation
     :return: the deserialized data structure
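The body of `deserialize_unsafe` sits outside this hunk; only the signature and docstring are shown. For context, a minimal sketch of what the renamed function presumably reduces to (the exact implementation here is an assumption, not code from the diff):

```python
import yaml

class AiiDALoader(yaml.UnsafeLoader):
    """AiiDA-specific yaml loader; only safe on trusted input."""

def deserialize_unsafe(serialized):
    # UnsafeLoader honours python-specific tags (!!python/object/...),
    # which is what lets numpy arrays and other objects round-trip.
    return yaml.load(serialized, Loader=AiiDALoader)
```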
15 changes: 14 additions & 1 deletion aiida/tools/importexport/dbimport/backends/common.py
@@ -13,7 +13,7 @@

 from aiida.common import timezone
 from aiida.common.progress_reporter import get_progress_reporter, create_callback
-from aiida.orm import Group, ImportGroup, Node, QueryBuilder
+from aiida.orm import Group, ImportGroup, Node, QueryBuilder, ProcessNode
 from aiida.tools.importexport.archive.readers import ArchiveReaderAbstract
 from aiida.tools.importexport.common import exceptions
 from aiida.tools.importexport.dbimport.utils import IMPORT_LOGGER
@@ -127,3 +127,16 @@ def _sanitize_extras(fields: dict) -> dict:
     if fields.get('node_type', '').endswith('code.Code.'):
         fields['extras'] = {key: value for key, value in fields['extras'].items() if not key == 'hidden'}
     return fields
+
+
+def _strip_checkpoints(fields: dict) -> dict:
+    """Remove checkpoint from attributes of process nodes.
+
+    :param fields: the database fields for the entity
+    """
+    if fields.get('node_type', '').startswith('process.'):
+        fields = copy.copy(fields)
[Inline review thread on the `fields = copy.copy(fields)` line]

Contributor: Why is the copy needed here? And if it is needed, shouldn't we use `deepcopy`? Since we are manipulating a key inside a nested dictionary, `fields['attributes']` will still have a reference to the original object, wouldn't it?

Member (PR author): TBH, I'm not sure why the copy is needed -- it is simply copy/pasted from `_sanitize_extras` above, and I don't have enough context on the whole import procedure to know whether it is necessary. As for `fields['attributes']` still having a reference to the original object: right, good catch. Maybe a complete `deepcopy` is overkill, but we could also `copy.copy` the `fields['attributes']`.

Member (PR author): Since we're now reconstructing the attributes dict, this should be resolved - but feel free to double-check @sphuber.

+        fields['attributes'] = {
+            key: value for key, value in fields['attributes'].items() if key != ProcessNode.CHECKPOINT_KEY
+        }
+    return fields
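The shallow-copy concern from the thread above is easy to reproduce. A standalone sketch (hypothetical field data, not code from the PR) of why `copy.copy` alone would have mutated the caller's dict, and why rebuilding `fields['attributes']`, as the merged code does, resolves it:

```python
import copy

fields_in = {'node_type': 'process.workflow.', 'attributes': {'checkpoint': 'blob', 'x': 1}}

# Shallow copy: the outer dict is new, but 'attributes' is still shared,
# so popping the key leaks the mutation back to the caller.
fields = copy.copy(fields_in)
fields['attributes'].pop('checkpoint')
assert 'checkpoint' not in fields_in['attributes']  # caller's dict mutated!

# Rebuilding the inner dict instead leaves the caller's dict intact.
fields_in = {'node_type': 'process.workflow.', 'attributes': {'checkpoint': 'blob', 'x': 1}}
fields = copy.copy(fields_in)
fields['attributes'] = {k: v for k, v in fields['attributes'].items() if k != 'checkpoint'}
assert 'checkpoint' in fields_in['attributes']  # caller's dict untouched
```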
4 changes: 3 additions & 1 deletion aiida/tools/importexport/dbimport/backends/django.py
@@ -32,7 +32,7 @@
 from aiida.tools.importexport.archive.readers import ArchiveReaderAbstract, get_reader

 from aiida.tools.importexport.dbimport.backends.common import (
-    _copy_node_repositories, _make_import_group, _sanitize_extras, MAX_COMPUTERS, MAX_GROUPS
+    _copy_node_repositories, _make_import_group, _sanitize_extras, _strip_checkpoints, MAX_COMPUTERS, MAX_GROUPS
 )
@@ -355,6 +355,8 @@ def _select_entity_data(
             if entity_name == NODE_ENTITY_NAME:
                 # format extras
                 fields = _sanitize_extras(fields)
+                # strip checkpoints
+                fields = _strip_checkpoints(fields)
                 if extras_mode_new != 'import':
                     fields.pop('extras', None)
             new_entries[entity_name][str(pk)] = fields
4 changes: 3 additions & 1 deletion aiida/tools/importexport/dbimport/backends/sqla.py
@@ -38,7 +38,7 @@
 from aiida.tools.importexport.archive.readers import ArchiveReaderAbstract, get_reader

 from aiida.tools.importexport.dbimport.backends.common import (
-    _copy_node_repositories, _make_import_group, _sanitize_extras, MAX_COMPUTERS, MAX_GROUPS
+    _copy_node_repositories, _make_import_group, _sanitize_extras, _strip_checkpoints, MAX_COMPUTERS, MAX_GROUPS
 )
@@ -392,6 +392,8 @@ def _select_entity_data(
             if entity_name == NODE_ENTITY_NAME:
                 # format extras
                 fields = _sanitize_extras(fields)
+                # strip checkpoints
+                fields = _strip_checkpoints(fields)
                 if extras_mode_new != 'import':
                     fields.pop('extras', None)
             new_entries[entity_name][str(pk)] = fields
4 changes: 2 additions & 2 deletions docs/source/nitpick-exceptions
@@ -139,8 +139,8 @@ py:class yaml.Dumper
 py:class yaml.Loader
 py:class yaml.dumper.Dumper
 py:class yaml.loader.Loader
-py:class yaml.FullLoader
-py:class yaml.loader.FullLoader
+py:class yaml.UnsafeLoader
+py:class yaml.loader.UnsafeLoader
 py:class uuid.UUID

 py:class psycopg2.extensions.cursor
2 changes: 1 addition & 1 deletion environment.yml
@@ -31,7 +31,7 @@ dependencies:
   - psycopg2-binary>=2.8.3,~=2.8
   - python-dateutil~=2.8
   - pytz~=2019.3
-  - pyyaml~=5.1.2
+  - pyyaml~=5.1
   - reentry~=1.3
   - simplejson~=3.16
   - sqlalchemy-utils~=0.36.0
2 changes: 1 addition & 1 deletion setup.json
@@ -45,7 +45,7 @@
         "psycopg2-binary~=2.8,>=2.8.3",
         "python-dateutil~=2.8",
         "pytz~=2019.3",
-        "pyyaml~=5.1.2",
+        "pyyaml~=5.1",
         "reentry~=1.3",
         "simplejson~=3.16",
         "sqlalchemy-utils~=0.36.0",
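The pin relaxation in both files follows PEP 440's compatible-release operator, which is what actually widens the allowed pyyaml range:

```python
# PEP 440 compatible-release ("~=") semantics for the two pins:
#   pyyaml~=5.1.2  ->  pyyaml>=5.1.2, <5.2   (only 5.1.x patch releases)
#   pyyaml~=5.1    ->  pyyaml>=5.1, <6.0     (any 5.x release from 5.1 onward)
```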
38 changes: 32 additions & 6 deletions tests/common/test_serialize.py
@@ -9,6 +9,10 @@
 ###########################################################################
 """Serialization tests"""

+import types
+
+import numpy as np
+
 from aiida import orm
 from aiida.orm.utils import serialize
 from aiida.backends.testbase import AiidaTestCase
@@ -28,7 +32,7 @@ def test_serialize_round_trip(self):
         data = {'test': 1, 'list': [1, 2, 3, node_a], 'dict': {('Si',): node_b, 'foo': 'bar'}, 'baz': 'aar'}

         serialized_data = serialize.serialize(data)
-        deserialized_data = serialize.deserialize(serialized_data)
+        deserialized_data = serialize.deserialize_unsafe(serialized_data)

         # For now manual element-for-element comparison until we come up with general
         # purpose function that can equate two node instances properly
@@ -49,29 +53,29 @@ def test_serialize_group(self):
         data = {'group': group_a}

         serialized_data = serialize.serialize(data)
-        deserialized_data = serialize.deserialize(serialized_data)
+        deserialized_data = serialize.deserialize_unsafe(serialized_data)

         self.assertEqual(data['group'].uuid, deserialized_data['group'].uuid)
         self.assertEqual(data['group'].label, deserialized_data['group'].label)

     def test_serialize_node_round_trip(self):
         """Test you can serialize and deserialize a node"""
         node = orm.Data().store()
-        deserialized = serialize.deserialize(serialize.serialize(node))
+        deserialized = serialize.deserialize_unsafe(serialize.serialize(node))
         self.assertEqual(node.uuid, deserialized.uuid)

     def test_serialize_group_round_trip(self):
         """Test you can serialize and deserialize a group"""
         group = orm.Group(label='test_serialize_group_round_trip').store()
-        deserialized = serialize.deserialize(serialize.serialize(group))
+        deserialized = serialize.deserialize_unsafe(serialize.serialize(group))

         self.assertEqual(group.uuid, deserialized.uuid)
         self.assertEqual(group.label, deserialized.label)

     def test_serialize_computer_round_trip(self):
         """Test you can serialize and deserialize a computer"""
         computer = self.computer
-        deserialized = serialize.deserialize(serialize.serialize(computer))
+        deserialized = serialize.deserialize_unsafe(serialize.serialize(computer))

         # pylint: disable=no-member
         self.assertEqual(computer.uuid, deserialized.uuid)
@@ -117,6 +121,28 @@ def test_mixed_attribute_normal_dict(self):
         attribute_dict['nested']['normal'] = {'a': 2}

         serialized = serialize.serialize(attribute_dict)
-        deserialized = serialize.deserialize(serialized)
+        deserialized = serialize.deserialize_unsafe(serialized)

         self.assertEqual(attribute_dict, deserialized)

+    def test_serialize_numpy(self):  # pylint: disable=no-self-use
+        """Regression test for #3709
+
+        Check that numpy arrays can be serialized.
+        """
+        data = np.array([1, 2, 3])
+
+        serialized = serialize.serialize(data)
+        deserialized = serialize.deserialize_unsafe(serialized)
+        assert np.all(data == deserialized)
+
+    def test_serialize_simplenamespace(self):  # pylint: disable=no-self-use
+        """Regression test for #3709
+
+        Check that `types.SimpleNamespace` can be serialized.
+        """
+        data = types.SimpleNamespace(a=1, b=2.1)
+
+        serialized = serialize.serialize(data)
+        deserialized = serialize.deserialize_unsafe(serialized)
+        assert data == deserialized
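For background on why the loader switch was needed at all: `FullLoader` refuses the python-specific tags that generic objects such as numpy arrays serialize to. A minimal sketch of the failure these tests guard against, assuming `pyyaml>=5.1` and using plain pyyaml rather than AiiDA's wrappers:

```python
import numpy as np
import yaml

# Dumping with the default (non-safe) Dumper emits python-specific tags.
dumped = yaml.dump(np.array([1, 2, 3]))

# UnsafeLoader reconstructs the object, so the array round-trips.
array = yaml.load(dumped, Loader=yaml.UnsafeLoader)
assert np.all(array == np.array([1, 2, 3]))

# FullLoader rejects python tags, which is why the old loader failed here.
try:
    yaml.load(dumped, Loader=yaml.FullLoader)
except yaml.constructor.ConstructorError as exc:
    print('FullLoader rejected the python tag:', exc)
```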
30 changes: 30 additions & 0 deletions tests/tools/importexport/test_specific_import.py
@@ -154,3 +154,33 @@ def test_cycle_structure_data(self):
         builder.append(orm.RemoteData, project=['uuid'], with_incoming='parent', tag='remote')
         builder.append(orm.CalculationNode, with_incoming='remote')
         self.assertGreater(len(builder.all()), 0)
+
+    def test_import_checkpoints(self):
+        """Check that process node checkpoints are stripped when importing.
+
+        The process node checkpoints need to be stripped because they
+        could be maliciously crafted to execute arbitrary code, since
+        we use the `yaml.UnsafeLoader` to load them.
+        """
+        node = orm.WorkflowNode().store()
+        node.set_checkpoint(12)
+        node.seal()
+        node_uuid = node.uuid
+        assert node.checkpoint == 12
+
+        with tempfile.NamedTemporaryFile() as handle:
+            nodes = [node]
+            export(nodes, filename=handle.name, overwrite=True)
+
+            # Check that we have the expected number of nodes in the database
+            self.assertEqual(orm.QueryBuilder().append(orm.Node).count(), len(nodes))
+
+            # Clean the database and verify there are no nodes left
+            self.clean_db()
+            assert orm.QueryBuilder().append(orm.Node).count() == 0
+
+            import_data(handle.name)
+
+            assert orm.QueryBuilder().append(orm.Node).count() == len(nodes)
+            node_new = orm.load_node(node_uuid)
+            assert node_new.checkpoint is None
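To make the threat model in the test docstring concrete: with `yaml.UnsafeLoader`, a crafted document can invoke arbitrary Python callables via python-specific tags. A standalone illustration (not code from this PR) of why imported checkpoints are stripped rather than loaded:

```python
import yaml

# A checkpoint-like payload a malicious archive could carry: the
# !!python/object/apply tag instructs UnsafeLoader to *call* os.system.
payload = "!!python/object/apply:os.system ['echo arbitrary code execution']"

yaml.load(payload, Loader=yaml.UnsafeLoader)  # executes the shell command

# yaml.safe_load(payload) would raise a ConstructorError instead, which is
# why untrusted input must never reach the unsafe loader.
```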