Skip to content

Commit

Permalink
👌 IMPROVE: Allow numpy arrays to be serialized on process checkpoints (…
Browse files Browse the repository at this point in the history
…#4730)

To allow objects such as numpy arrays to be serialized to a process checkpoint,
the `AiiDALoader` now inherits from `yaml.UnsafeLoader` instead of `yaml.FullLoader`.

Note, this change represents a potential security risk,
whereby maliciously crafted code could be added to the serialized data
and then loaded upon importing an archive.
To mitigate this risk, the function `deserialize` has been renamed to `deserialize_unsafe`,
and node checkpoint attributes are removed before importing an archive.
This code is not part of the public API,
and so we assume no specific deprecations are required.

This change has also allowed for a relaxation of the `pyaml` pinning (to 5.2),
although it should be noted that this upgrade will not be realised
until a similar relaxation is implemented in plumpy.
  • Loading branch information
greschd authored Jun 16, 2021
1 parent 28ef3a5 commit 1bc9dbe
Show file tree
Hide file tree
Showing 12 changed files with 95 additions and 23 deletions.
2 changes: 1 addition & 1 deletion aiida/engine/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def load_checkpoint(self, pid: Hashable, tag: Optional[str] = None) -> plumpy.pe
raise PersistenceError(f'Calculation<{calculation.pk}> does not have a saved checkpoint')

try:
bundle = serialize.deserialize(checkpoint)
bundle = serialize.deserialize_unsafe(checkpoint)
except Exception:
raise PersistenceError(f'Failed to load the checkpoint for process<{pid}>: {traceback.format_exc()}')

Expand Down
2 changes: 1 addition & 1 deletion aiida/engine/processes/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@ def decode_input_args(self, encoded: str) -> Dict[str, Any]: # pylint: disable=
:param encoded: encoded (serialized) inputs
:return: The decoded input args
"""
return serialize.deserialize(encoded)
return serialize.deserialize_unsafe(encoded)

def update_node_state(self, state: plumpy.process_states.State) -> None:
self.update_outputs()
Expand Down
2 changes: 1 addition & 1 deletion aiida/manage/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def create_communicator(
if with_orm:
from aiida.orm.utils import serialize
encoder = functools.partial(serialize.serialize, encoding='utf-8')
decoder = serialize.deserialize
decoder = serialize.deserialize_unsafe
else:
# used by verdi status to get a communicator without needing to load the dbenv
from aiida.common import json
Expand Down
13 changes: 6 additions & 7 deletions aiida/orm/utils/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,13 +176,11 @@ def represent_data(self, data):
return super().represent_data(data)


class AiiDALoader(yaml.FullLoader):
class AiiDALoader(yaml.UnsafeLoader):
"""AiiDA specific yaml loader
.. note:: we subclass the `FullLoader` which is the one that since `pyyaml>=5.1` is the loader that prevents
arbitrary code execution. Even though this is in principle only used internally, one could imagine someone
sharing a database with a maliciously crafted process instance dump, which when reloaded could execute arbitrary
code. This load prevents this: https://github.com/yaml/pyyaml/wiki/PyYAML-yaml.load(input)-Deprecation
.. note:: The `AiiDALoader` should only be used on trusted input, because it uses the `yaml.UnsafeLoader`. When
importing a shared database, we strip all process node checkpoints to avoid this being a security risk.
"""


Expand Down Expand Up @@ -219,10 +217,11 @@ def serialize(data, encoding=None):
return serialized


def deserialize(serialized):
def deserialize_unsafe(serialized):
"""Deserialize a yaml dump that represents a serialized data structure.
.. note:: no need to use `yaml.safe_load` here because the `Loader` will ensure that loading is safe.
.. note:: This function should not be used on untrusted input, because
it is built upon `yaml.UnsafeLoader`.
:param serialized: a yaml serialized string representation
:return: the deserialized data structure
Expand Down
15 changes: 14 additions & 1 deletion aiida/tools/importexport/dbimport/backends/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from aiida.common import timezone
from aiida.common.progress_reporter import get_progress_reporter, create_callback
from aiida.orm import Group, ImportGroup, Node, QueryBuilder
from aiida.orm import Group, ImportGroup, Node, QueryBuilder, ProcessNode
from aiida.tools.importexport.archive.readers import ArchiveReaderAbstract
from aiida.tools.importexport.common import exceptions
from aiida.tools.importexport.dbimport.utils import IMPORT_LOGGER
Expand Down Expand Up @@ -127,3 +127,16 @@ def _sanitize_extras(fields: dict) -> dict:
if fields.get('node_type', '').endswith('code.Code.'):
fields['extras'] = {key: value for key, value in fields['extras'].items() if not key == 'hidden'}
return fields


def _strip_checkpoints(fields: dict) -> dict:
"""Remove checkpoint from attributes of process nodes.
:param fields: the database fields for the entity
"""
if fields.get('node_type', '').startswith('process.'):
fields = copy.copy(fields)
fields['attributes'] = {
key: value for key, value in fields['attributes'].items() if key != ProcessNode.CHECKPOINT_KEY
}
return fields
4 changes: 3 additions & 1 deletion aiida/tools/importexport/dbimport/backends/django.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from aiida.tools.importexport.archive.readers import ArchiveReaderAbstract, get_reader

from aiida.tools.importexport.dbimport.backends.common import (
_copy_node_repositories, _make_import_group, _sanitize_extras, MAX_COMPUTERS, MAX_GROUPS
_copy_node_repositories, _make_import_group, _sanitize_extras, _strip_checkpoints, MAX_COMPUTERS, MAX_GROUPS
)


Expand Down Expand Up @@ -355,6 +355,8 @@ def _select_entity_data(
if entity_name == NODE_ENTITY_NAME:
# format extras
fields = _sanitize_extras(fields)
# strip checkpoints
fields = _strip_checkpoints(fields)
if extras_mode_new != 'import':
fields.pop('extras', None)
new_entries[entity_name][str(pk)] = fields
Expand Down
4 changes: 3 additions & 1 deletion aiida/tools/importexport/dbimport/backends/sqla.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from aiida.tools.importexport.archive.readers import ArchiveReaderAbstract, get_reader

from aiida.tools.importexport.dbimport.backends.common import (
_copy_node_repositories, _make_import_group, _sanitize_extras, MAX_COMPUTERS, MAX_GROUPS
_copy_node_repositories, _make_import_group, _sanitize_extras, _strip_checkpoints, MAX_COMPUTERS, MAX_GROUPS
)


Expand Down Expand Up @@ -392,6 +392,8 @@ def _select_entity_data(
if entity_name == NODE_ENTITY_NAME:
# format extras
fields = _sanitize_extras(fields)
# strip checkpoints
fields = _strip_checkpoints(fields)
if extras_mode_new != 'import':
fields.pop('extras', None)
new_entries[entity_name][str(pk)] = fields
Expand Down
4 changes: 2 additions & 2 deletions docs/source/nitpick-exceptions
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ py:class yaml.Dumper
py:class yaml.Loader
py:class yaml.dumper.Dumper
py:class yaml.loader.Loader
py:class yaml.FullLoader
py:class yaml.loader.FullLoader
py:class yaml.UnsafeLoader
py:class yaml.loader.UnsafeLoader
py:class uuid.UUID

py:class psycopg2.extensions.cursor
Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ dependencies:
- psycopg2-binary>=2.8.3,~=2.8
- python-dateutil~=2.8
- pytz~=2019.3
- pyyaml~=5.1.2
- pyyaml~=5.1
- reentry~=1.3
- simplejson~=3.16
- sqlalchemy-utils~=0.36.0
Expand Down
2 changes: 1 addition & 1 deletion setup.json
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
"psycopg2-binary~=2.8,>=2.8.3",
"python-dateutil~=2.8",
"pytz~=2019.3",
"pyyaml~=5.1.2",
"pyyaml~=5.1",
"reentry~=1.3",
"simplejson~=3.16",
"sqlalchemy-utils~=0.36.0",
Expand Down
38 changes: 32 additions & 6 deletions tests/common/test_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
###########################################################################
"""Serialization tests"""

import types

import numpy as np

from aiida import orm
from aiida.orm.utils import serialize
from aiida.backends.testbase import AiidaTestCase
Expand All @@ -28,7 +32,7 @@ def test_serialize_round_trip(self):
data = {'test': 1, 'list': [1, 2, 3, node_a], 'dict': {('Si',): node_b, 'foo': 'bar'}, 'baz': 'aar'}

serialized_data = serialize.serialize(data)
deserialized_data = serialize.deserialize(serialized_data)
deserialized_data = serialize.deserialize_unsafe(serialized_data)

# For now manual element-for-element comparison until we come up with general
# purpose function that can equate two node instances properly
Expand All @@ -49,29 +53,29 @@ def test_serialize_group(self):
data = {'group': group_a}

serialized_data = serialize.serialize(data)
deserialized_data = serialize.deserialize(serialized_data)
deserialized_data = serialize.deserialize_unsafe(serialized_data)

self.assertEqual(data['group'].uuid, deserialized_data['group'].uuid)
self.assertEqual(data['group'].label, deserialized_data['group'].label)

def test_serialize_node_round_trip(self):
"""Test you can serialize and deserialize a node"""
node = orm.Data().store()
deserialized = serialize.deserialize(serialize.serialize(node))
deserialized = serialize.deserialize_unsafe(serialize.serialize(node))
self.assertEqual(node.uuid, deserialized.uuid)

def test_serialize_group_round_trip(self):
"""Test you can serialize and deserialize a group"""
group = orm.Group(label='test_serialize_group_round_trip').store()
deserialized = serialize.deserialize(serialize.serialize(group))
deserialized = serialize.deserialize_unsafe(serialize.serialize(group))

self.assertEqual(group.uuid, deserialized.uuid)
self.assertEqual(group.label, deserialized.label)

def test_serialize_computer_round_trip(self):
"""Test you can serialize and deserialize a computer"""
computer = self.computer
deserialized = serialize.deserialize(serialize.serialize(computer))
deserialized = serialize.deserialize_unsafe(serialize.serialize(computer))

# pylint: disable=no-member
self.assertEqual(computer.uuid, deserialized.uuid)
Expand Down Expand Up @@ -117,6 +121,28 @@ def test_mixed_attribute_normal_dict(self):
attribute_dict['nested']['normal'] = {'a': 2}

serialized = serialize.serialize(attribute_dict)
deserialized = serialize.deserialize(serialized)
deserialized = serialize.deserialize_unsafe(serialized)

self.assertEqual(attribute_dict, deserialized)

def test_serialize_numpy(self): # pylint: disable=no-self-use
"""Regression test for #3709
Check that numpy arrays can be serialized.
"""
data = np.array([1, 2, 3])

serialized = serialize.serialize(data)
deserialized = serialize.deserialize_unsafe(serialized)
assert np.all(data == deserialized)

def test_serialize_simplenamespace(self): # pylint: disable=no-self-use
"""Regression test for #3709
Check that `types.SimpleNamespace` can be serialized.
"""
data = types.SimpleNamespace(a=1, b=2.1)

serialized = serialize.serialize(data)
deserialized = serialize.deserialize_unsafe(serialized)
assert data == deserialized
30 changes: 30 additions & 0 deletions tests/tools/importexport/test_specific_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,33 @@ def test_cycle_structure_data(self):
builder.append(orm.RemoteData, project=['uuid'], with_incoming='parent', tag='remote')
builder.append(orm.CalculationNode, with_incoming='remote')
self.assertGreater(len(builder.all()), 0)

def test_import_checkpoints(self):
"""Check that process node checkpoints are stripped when importing.
The process node checkpoints need to be stripped because they
could be maliciously crafted to execute arbitrary code, since
we use the `yaml.UnsafeLoader` to load them.
"""
node = orm.WorkflowNode().store()
node.set_checkpoint(12)
node.seal()
node_uuid = node.uuid
assert node.checkpoint == 12

with tempfile.NamedTemporaryFile() as handle:
nodes = [node]
export(nodes, filename=handle.name, overwrite=True)

# Check that we have the expected number of nodes in the database
self.assertEqual(orm.QueryBuilder().append(orm.Node).count(), len(nodes))

# Clean the database and verify there are no nodes left
self.clean_db()
assert orm.QueryBuilder().append(orm.Node).count() == 0

import_data(handle.name)

assert orm.QueryBuilder().append(orm.Node).count() == len(nodes)
node_new = orm.load_node(node_uuid)
assert node_new.checkpoint is None

0 comments on commit 1bc9dbe

Please sign in to comment.