-
Notifications
You must be signed in to change notification settings - Fork 192
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Enum's are often used data types that currently cannot be used in process inputs, even if the ports are `non_db`, because even the `non_db` inputs are serialized as part of the entire `Process` instance that is serialized when a checkpoint is created. The custom YAML serializer and deserializer that are defined in the `aiida.orm.utils.serialize` module that are used for serializing process instances and their inputs, are updated with a dumper and a loader for `Enum` instances. The full class identifier is generated from the default object loader provided by `plumpy` and concatenated by the `|` sign and the enum's values, it is represented as a YAML scalar. The loader uses the same object loader to load the enum class from the identifier string and reconstruct the enum instance from the serialized value. The original tests were written in `tests/common/test_serialize.py` and are moved to `tests/orm/utils/test_serialize.py` to conform with the standard to mirror the package hierarchy. The tests are converted to `pytest` style and a new test is added by performing a round trip of an enum instance (de)serialization.
- Loading branch information
Showing
3 changed files
with
198 additions
and
149 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
# -*- coding: utf-8 -*- | ||
########################################################################### | ||
# Copyright (c), The AiiDA team. All rights reserved. # | ||
# This file is part of the AiiDA code. # | ||
# # | ||
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # | ||
# For further information on the license, see the LICENSE.txt file # | ||
# For further information please visit http://www.aiida.net # | ||
########################################################################### | ||
"""Tests for the :mod:`aiida.orm.utils.serialize` module.""" | ||
import types | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from aiida import orm | ||
from aiida.common.links import LinkType | ||
from aiida.orm.utils import serialize | ||
|
||
pytestmark = pytest.mark.usefixtures('clear_database_before_test') | ||
|
||
|
||
def test_serialize_round_trip(): | ||
""" | ||
Test the serialization of a dictionary with Nodes in various data structure | ||
Also make sure that the serialized data is json-serializable | ||
""" | ||
node_a = orm.Data().store() | ||
node_b = orm.Data().store() | ||
|
||
data = {'test': 1, 'list': [1, 2, 3, node_a], 'dict': {('Si',): node_b, 'foo': 'bar'}, 'baz': 'aar'} | ||
|
||
serialized_data = serialize.serialize(data) | ||
deserialized_data = serialize.deserialize_unsafe(serialized_data) | ||
|
||
# For now manual element-for-element comparison until we come up with general | ||
# purpose function that can equate two node instances properly | ||
assert data['test'] == deserialized_data['test'] | ||
assert data['baz'] == deserialized_data['baz'] | ||
assert data['list'][:3] == deserialized_data['list'][:3] | ||
assert data['list'][3].uuid == deserialized_data['list'][3].uuid | ||
assert data['dict'][('Si',)].uuid == deserialized_data['dict'][('Si',)].uuid | ||
|
||
|
||
def test_serialize_group(): | ||
""" | ||
Test that serialization and deserialization of Groups works. | ||
Also make sure that the serialized data is json-serializable | ||
""" | ||
group_name = 'groupie' | ||
group_a = orm.Group(label=group_name).store() | ||
|
||
data = {'group': group_a} | ||
|
||
serialized_data = serialize.serialize(data) | ||
deserialized_data = serialize.deserialize_unsafe(serialized_data) | ||
|
||
assert data['group'].uuid == deserialized_data['group'].uuid | ||
assert data['group'].label == deserialized_data['group'].label | ||
|
||
|
||
def test_serialize_node_round_trip(): | ||
"""Test you can serialize and deserialize a node""" | ||
node = orm.Data().store() | ||
deserialized = serialize.deserialize_unsafe(serialize.serialize(node)) | ||
assert node.uuid == deserialized.uuid | ||
|
||
|
||
def test_serialize_group_round_trip(): | ||
"""Test you can serialize and deserialize a group""" | ||
group = orm.Group(label='test_serialize_group_round_trip').store() | ||
deserialized = serialize.deserialize_unsafe(serialize.serialize(group)) | ||
|
||
assert group.uuid == deserialized.uuid | ||
assert group.label == deserialized.label | ||
|
||
|
||
def test_serialize_computer_round_trip(aiida_localhost): | ||
"""Test you can serialize and deserialize a computer""" | ||
deserialized = serialize.deserialize_unsafe(serialize.serialize(aiida_localhost)) | ||
|
||
# pylint: disable=no-member | ||
assert aiida_localhost.uuid == deserialized.uuid | ||
assert aiida_localhost.label == deserialized.label | ||
|
||
|
||
def test_serialize_unstored_node(): | ||
"""Test that you can't serialize an unstored node""" | ||
node = orm.Data() | ||
|
||
with pytest.raises(ValueError): | ||
serialize.serialize(node) | ||
|
||
|
||
def test_serialize_unstored_group(): | ||
"""Test that you can't serialize an unstored group""" | ||
group = orm.Group(label='test_serialize_unstored_group') | ||
|
||
with pytest.raises(ValueError): | ||
serialize.serialize(group) | ||
|
||
|
||
def test_serialize_unstored_computer(): | ||
"""Test that you can't serialize an unstored node""" | ||
computer = orm.Computer('test_computer', 'test_host') | ||
|
||
with pytest.raises(ValueError): | ||
serialize.serialize(computer) | ||
|
||
|
||
def test_mixed_attribute_normal_dict(): | ||
"""Regression test for #3092. | ||
The yaml mapping constructor in `aiida.orm.utils.serialize` was not properly "deeply" reconstructing nested | ||
mappings, causing a mix of attribute dictionaries and normal dictionaries to lose information in a round-trip. | ||
If a nested `AttributeDict` contained a normal dictionary, the content of the latter would be lost during the | ||
deserialization, despite the information being present in the serialized yaml dump. | ||
""" | ||
from aiida.common.extendeddicts import AttributeDict | ||
|
||
# Construct a nested `AttributeDict`, which should make all nested dictionaries `AttributeDicts` recursively | ||
dictionary = {'nested': AttributeDict({'dict': 'string', 'value': 1})} | ||
attribute_dict = AttributeDict(dictionary) | ||
|
||
# Now add a normal dictionary in the attribute dictionary | ||
attribute_dict['nested']['normal'] = {'a': 2} | ||
|
||
serialized = serialize.serialize(attribute_dict) | ||
deserialized = serialize.deserialize_unsafe(serialized) | ||
|
||
assert attribute_dict, deserialized | ||
|
||
|
||
def test_serialize_numpy(): | ||
"""Regression test for #3709 | ||
Check that numpy arrays can be serialized. | ||
""" | ||
data = np.array([1, 2, 3]) | ||
|
||
serialized = serialize.serialize(data) | ||
deserialized = serialize.deserialize_unsafe(serialized) | ||
assert np.all(data == deserialized) | ||
|
||
|
||
def test_serialize_simplenamespace(): | ||
"""Regression test for #3709 | ||
Check that `types.SimpleNamespace` can be serialized. | ||
""" | ||
data = types.SimpleNamespace(a=1, b=2.1) | ||
|
||
serialized = serialize.serialize(data) | ||
deserialized = serialize.deserialize_unsafe(serialized) | ||
assert data == deserialized | ||
|
||
|
||
def test_enum(): | ||
"""Test serialization and deserialization of an ``Enum``.""" | ||
enum = LinkType.RETURN | ||
serialized = serialize.serialize(enum) | ||
assert isinstance(serialized, str) | ||
|
||
deserialized = serialize.deserialize_unsafe(serialized) | ||
assert deserialized == enum |