From 1bc9dbe43ff31b737ce29aca605ae15985d38ca0 Mon Sep 17 00:00:00 2001 From: Dominik Gresch Date: Wed, 16 Jun 2021 10:09:50 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=91=8C=20IMPROVE:=20Allow=20numpy=20array?= =?UTF-8?q?s=20to=20be=20serialized=20on=20process=20checkpoints=20(#4730)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit To allow objects such as numpy arrays to be serialized to a process checkpoint, the `AiiDALoader` now inherits from `yaml.UnsafeLoader` instead of `yaml.FullLoader`. Note that this change represents a potential security risk, whereby maliciously crafted code could be added to the serialized data and then loaded upon importing an archive. To mitigate this risk, the function `deserialize` has been renamed to `deserialize_unsafe`, and node checkpoint attributes are removed before importing an archive. This code is not part of the public API, and so we assume no specific deprecations are required. This change has also allowed the `pyyaml` pinning to be relaxed (from `~=5.1.2` to `~=5.1`), although the upgrade will not be realised until a similar relaxation is implemented in plumpy. --- aiida/engine/persistence.py | 2 +- aiida/engine/processes/process.py | 2 +- aiida/manage/manager.py | 2 +- aiida/orm/utils/serialize.py | 13 +++---- .../importexport/dbimport/backends/common.py | 15 +++++++- .../importexport/dbimport/backends/django.py | 4 +- .../importexport/dbimport/backends/sqla.py | 4 +- docs/source/nitpick-exceptions | 4 +- environment.yml | 2 +- setup.json | 2 +- tests/common/test_serialize.py | 38 ++++++++++++++++--- .../importexport/test_specific_import.py | 30 +++++++++++++++ 12 files changed, 95 insertions(+), 23 deletions(-) diff --git a/aiida/engine/persistence.py b/aiida/engine/persistence.py index 2ccdac03c1..5ee9970b14 100644 --- a/aiida/engine/persistence.py +++ b/aiida/engine/persistence.py @@ -121,7 +121,7 @@ def load_checkpoint(self, pid: Hashable, tag: Optional[str] = None) -> plumpy.pe raise PersistenceError(f'Calculation<{calculation.pk}> does not have a saved checkpoint') try: - bundle = serialize.deserialize(checkpoint) + bundle = serialize.deserialize_unsafe(checkpoint) except Exception: raise PersistenceError(f'Failed to load the checkpoint for process<{pid}>: {traceback.format_exc()}') diff --git a/aiida/engine/processes/process.py b/aiida/engine/processes/process.py index 12d4d9dc6c..3064bfe75b 100644 --- a/aiida/engine/processes/process.py +++ b/aiida/engine/processes/process.py @@ -604,7 +604,7 @@ def decode_input_args(self, encoded: str) -> Dict[str, Any]: # pylint: disable= :param encoded: encoded (serialized) inputs :return: The decoded input args """ - return serialize.deserialize(encoded) + return serialize.deserialize_unsafe(encoded) def update_node_state(self, state: plumpy.process_states.State) -> None: self.update_outputs() diff --git a/aiida/manage/manager.py b/aiida/manage/manager.py index 351635426a..0e0c94c2da 100644 --- a/aiida/manage/manager.py +++ b/aiida/manage/manager.py @@ -264,7 +264,7 @@ def create_communicator( if with_orm: from aiida.orm.utils import serialize encoder = functools.partial(serialize.serialize, encoding='utf-8') - decoder = serialize.deserialize + decoder = serialize.deserialize_unsafe else: # used by verdi status to get a communicator without needing to load the dbenv from aiida.common import json diff --git a/aiida/orm/utils/serialize.py b/aiida/orm/utils/serialize.py index ae1ebf49dc..ba7fdd85ad 100644 --- a/aiida/orm/utils/serialize.py +++ 
b/aiida/orm/utils/serialize.py @@ -176,13 +176,11 @@ def represent_data(self, data): return super().represent_data(data) -class AiiDALoader(yaml.FullLoader): +class AiiDALoader(yaml.UnsafeLoader): """AiiDA specific yaml loader - .. note:: we subclass the `FullLoader` which is the one that since `pyyaml>=5.1` is the loader that prevents - arbitrary code execution. Even though this is in principle only used internally, one could imagine someone - sharing a database with a maliciously crafted process instance dump, which when reloaded could execute arbitrary - code. This load prevents this: https://github.com/yaml/pyyaml/wiki/PyYAML-yaml.load(input)-Deprecation + .. note:: The `AiiDALoader` should only be used on trusted input, because it uses the `yaml.UnsafeLoader`. When + importing a shared database, we strip all process node checkpoints to avoid this being a security risk. """ @@ -219,10 +217,11 @@ def serialize(data, encoding=None): return serialized -def deserialize(serialized): +def deserialize_unsafe(serialized): """Deserialize a yaml dump that represents a serialized data structure. - .. note:: no need to use `yaml.safe_load` here because the `Loader` will ensure that loading is safe. + .. note:: This function should not be used on untrusted input, because + it is built upon `yaml.UnsafeLoader`. :param serialized: a yaml serialized string representation :return: the deserialized data structure diff --git a/aiida/tools/importexport/dbimport/backends/common.py b/aiida/tools/importexport/dbimport/backends/common.py index 74916ec5bd..7d550a1baf 100644 --- a/aiida/tools/importexport/dbimport/backends/common.py +++ b/aiida/tools/importexport/dbimport/backends/common.py @@ -13,7 +13,7 @@ from aiida.common import timezone from aiida.common.progress_reporter import get_progress_reporter, create_callback -from aiida.orm import Group, ImportGroup, Node, QueryBuilder +from aiida.orm import Group, ImportGroup, Node, QueryBuilder, ProcessNode from aiida.tools.importexport.archive.readers import ArchiveReaderAbstract from aiida.tools.importexport.common import exceptions from aiida.tools.importexport.dbimport.utils import IMPORT_LOGGER @@ -127,3 +127,16 @@ def _sanitize_extras(fields: dict) -> dict: if fields.get('node_type', '').endswith('code.Code.'): fields['extras'] = {key: value for key, value in fields['extras'].items() if not key == 'hidden'} return fields + + +def _strip_checkpoints(fields: dict) -> dict: + """Remove checkpoint from attributes of process nodes. 
+ + :param fields: the database fields for the entity + """ + if fields.get('node_type', '').startswith('process.'): + fields = copy.copy(fields) + fields['attributes'] = { + key: value for key, value in fields['attributes'].items() if key != ProcessNode.CHECKPOINT_KEY + } + return fields diff --git a/aiida/tools/importexport/dbimport/backends/django.py b/aiida/tools/importexport/dbimport/backends/django.py index a63bfc370e..d229382534 100644 --- a/aiida/tools/importexport/dbimport/backends/django.py +++ b/aiida/tools/importexport/dbimport/backends/django.py @@ -32,7 +32,7 @@ from aiida.tools.importexport.archive.readers import ArchiveReaderAbstract, get_reader from aiida.tools.importexport.dbimport.backends.common import ( - _copy_node_repositories, _make_import_group, _sanitize_extras, MAX_COMPUTERS, MAX_GROUPS + _copy_node_repositories, _make_import_group, _sanitize_extras, _strip_checkpoints, MAX_COMPUTERS, MAX_GROUPS ) @@ -355,6 +355,8 @@ def _select_entity_data( if entity_name == NODE_ENTITY_NAME: # format extras fields = _sanitize_extras(fields) + # strip checkpoints + fields = _strip_checkpoints(fields) if extras_mode_new != 'import': fields.pop('extras', None) new_entries[entity_name][str(pk)] = fields diff --git a/aiida/tools/importexport/dbimport/backends/sqla.py b/aiida/tools/importexport/dbimport/backends/sqla.py index e14aa60e15..aaaf425eee 100644 --- a/aiida/tools/importexport/dbimport/backends/sqla.py +++ b/aiida/tools/importexport/dbimport/backends/sqla.py @@ -38,7 +38,7 @@ from aiida.tools.importexport.archive.readers import ArchiveReaderAbstract, get_reader from aiida.tools.importexport.dbimport.backends.common import ( - _copy_node_repositories, _make_import_group, _sanitize_extras, MAX_COMPUTERS, MAX_GROUPS + _copy_node_repositories, _make_import_group, _sanitize_extras, _strip_checkpoints, MAX_COMPUTERS, MAX_GROUPS ) @@ -392,6 +392,8 @@ def _select_entity_data( if entity_name == NODE_ENTITY_NAME: # format extras fields = _sanitize_extras(fields) + # strip checkpoints + fields = _strip_checkpoints(fields) if extras_mode_new != 'import': fields.pop('extras', None) new_entries[entity_name][str(pk)] = fields diff --git a/docs/source/nitpick-exceptions b/docs/source/nitpick-exceptions index 8978b61951..3dae986e7d 100644 --- a/docs/source/nitpick-exceptions +++ b/docs/source/nitpick-exceptions @@ -139,8 +139,8 @@ py:class yaml.Dumper py:class yaml.Loader py:class yaml.dumper.Dumper py:class yaml.loader.Loader -py:class yaml.FullLoader -py:class yaml.loader.FullLoader +py:class yaml.UnsafeLoader +py:class yaml.loader.UnsafeLoader py:class uuid.UUID py:class psycopg2.extensions.cursor diff --git a/environment.yml b/environment.yml index 85b0d42889..5a73076314 100644 --- a/environment.yml +++ b/environment.yml @@ -31,7 +31,7 @@ dependencies: - psycopg2-binary>=2.8.3,~=2.8 - python-dateutil~=2.8 - pytz~=2019.3 -- pyyaml~=5.1.2 +- pyyaml~=5.1 - reentry~=1.3 - simplejson~=3.16 - sqlalchemy-utils~=0.36.0 diff --git a/setup.json b/setup.json index 28f0ad9e48..d7cf07044b 100644 --- a/setup.json +++ b/setup.json @@ -45,7 +45,7 @@ "psycopg2-binary~=2.8,>=2.8.3", "python-dateutil~=2.8", "pytz~=2019.3", - "pyyaml~=5.1.2", + "pyyaml~=5.1", "reentry~=1.3", "simplejson~=3.16", "sqlalchemy-utils~=0.36.0", diff --git a/tests/common/test_serialize.py b/tests/common/test_serialize.py index 720456678f..3b86b0cb58 100644 --- a/tests/common/test_serialize.py +++ b/tests/common/test_serialize.py @@ -9,6 +9,10 @@ ########################################################################### 
"""Serialization tests""" +import types + +import numpy as np + from aiida import orm from aiida.orm.utils import serialize from aiida.backends.testbase import AiidaTestCase @@ -28,7 +32,7 @@ def test_serialize_round_trip(self): data = {'test': 1, 'list': [1, 2, 3, node_a], 'dict': {('Si',): node_b, 'foo': 'bar'}, 'baz': 'aar'} serialized_data = serialize.serialize(data) - deserialized_data = serialize.deserialize(serialized_data) + deserialized_data = serialize.deserialize_unsafe(serialized_data) # For now manual element-for-element comparison until we come up with general # purpose function that can equate two node instances properly @@ -49,7 +53,7 @@ def test_serialize_group(self): data = {'group': group_a} serialized_data = serialize.serialize(data) - deserialized_data = serialize.deserialize(serialized_data) + deserialized_data = serialize.deserialize_unsafe(serialized_data) self.assertEqual(data['group'].uuid, deserialized_data['group'].uuid) self.assertEqual(data['group'].label, deserialized_data['group'].label) @@ -57,13 +61,13 @@ def test_serialize_group(self): def test_serialize_node_round_trip(self): """Test you can serialize and deserialize a node""" node = orm.Data().store() - deserialized = serialize.deserialize(serialize.serialize(node)) + deserialized = serialize.deserialize_unsafe(serialize.serialize(node)) self.assertEqual(node.uuid, deserialized.uuid) def test_serialize_group_round_trip(self): """Test you can serialize and deserialize a group""" group = orm.Group(label='test_serialize_group_round_trip').store() - deserialized = serialize.deserialize(serialize.serialize(group)) + deserialized = serialize.deserialize_unsafe(serialize.serialize(group)) self.assertEqual(group.uuid, deserialized.uuid) self.assertEqual(group.label, deserialized.label) @@ -71,7 +75,7 @@ def test_serialize_group_round_trip(self): def test_serialize_computer_round_trip(self): """Test you can serialize and deserialize a computer""" computer = self.computer - deserialized = serialize.deserialize(serialize.serialize(computer)) + deserialized = serialize.deserialize_unsafe(serialize.serialize(computer)) # pylint: disable=no-member self.assertEqual(computer.uuid, deserialized.uuid) @@ -117,6 +121,28 @@ def test_mixed_attribute_normal_dict(self): attribute_dict['nested']['normal'] = {'a': 2} serialized = serialize.serialize(attribute_dict) - deserialized = serialize.deserialize(serialized) + deserialized = serialize.deserialize_unsafe(serialized) self.assertEqual(attribute_dict, deserialized) + + def test_serialize_numpy(self): # pylint: disable=no-self-use + """Regression test for #3709 + + Check that numpy arrays can be serialized. + """ + data = np.array([1, 2, 3]) + + serialized = serialize.serialize(data) + deserialized = serialize.deserialize_unsafe(serialized) + assert np.all(data == deserialized) + + def test_serialize_simplenamespace(self): # pylint: disable=no-self-use + """Regression test for #3709 + + Check that `types.SimpleNamespace` can be serialized. 
+ """ + data = types.SimpleNamespace(a=1, b=2.1) + + serialized = serialize.serialize(data) + deserialized = serialize.deserialize_unsafe(serialized) + assert data == deserialized diff --git a/tests/tools/importexport/test_specific_import.py b/tests/tools/importexport/test_specific_import.py index 26bb746a17..a1808cc26f 100644 --- a/tests/tools/importexport/test_specific_import.py +++ b/tests/tools/importexport/test_specific_import.py @@ -154,3 +154,33 @@ def test_cycle_structure_data(self): builder.append(orm.RemoteData, project=['uuid'], with_incoming='parent', tag='remote') builder.append(orm.CalculationNode, with_incoming='remote') self.assertGreater(len(builder.all()), 0) + + def test_import_checkpoints(self): + """Check that process node checkpoints are stripped when importing. + + The process node checkpoints need to be stripped because they + could be maliciously crafted to execute arbitrary code, since + we use the `yaml.UnsafeLoader` to load them. + """ + node = orm.WorkflowNode().store() + node.set_checkpoint(12) + node.seal() + node_uuid = node.uuid + assert node.checkpoint == 12 + + with tempfile.NamedTemporaryFile() as handle: + nodes = [node] + export(nodes, filename=handle.name, overwrite=True) + + # Check that we have the expected number of nodes in the database + self.assertEqual(orm.QueryBuilder().append(orm.Node).count(), len(nodes)) + + # Clean the database and verify there are no nodes left + self.clean_db() + assert orm.QueryBuilder().append(orm.Node).count() == 0 + + import_data(handle.name) + + assert orm.QueryBuilder().append(orm.Node).count() == len(nodes) + node_new = orm.load_node(node_uuid) + assert node_new.checkpoint is None