Skip to content

Commit

Permalink
Add support for Enum types to aiida.orm.utils.serialize (#5218)
Browse files Browse the repository at this point in the history
Enum's are often used data types that currently cannot be used in
process inputs, even if the ports are `non_db`, because even the
`non_db` inputs are serialized as part of the entire `Process` instance
that is serialized when a checkpoint is created.

The custom YAML serializer and deserializer that are defined in the
`aiida.orm.utils.serialize` module that are used for serializing process
instances and their inputs, are updated with a dumper and a loader for
`Enum` instances. The full class identifier is generated from the
default object loader provided by `plumpy` and concatenated by the `|`
sign and the enum's values, it is represented as a YAML scalar.

The loader uses the same object loader to load the enum class from the
identifier string and reconstruct the enum instance from the serialized
value.

The original tests were written in `tests/common/test_serialize.py` and
are moved to `tests/orm/utils/test_serialize.py` to conform with the
standard to mirror the package hierarchy. The tests are converted to
`pytest` style and a new test is added by performing a round trip of an
enum instance (de)serialization.
  • Loading branch information
sphuber authored Nov 5, 2021
1 parent 7cea5be commit 8acdc24
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 149 deletions.
33 changes: 32 additions & 1 deletion aiida/orm/utils/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,17 @@
checkpoints and messages in the RabbitMQ queue so do so with caution. It is fine to add representers
for new types though.
"""
from enum import Enum
from functools import partial

from plumpy import Bundle
from plumpy import Bundle, get_object_loader
from plumpy.utils import AttributesFrozendict
import yaml

from aiida import orm
from aiida.common import AttributeDict

_ENUM_TAG = '!enum'
_NODE_TAG = '!aiida_node'
_GROUP_TAG = '!aiida_group'
_COMPUTER_TAG = '!aiida_computer'
Expand All @@ -31,6 +33,33 @@
_PLUMPY_BUNDLE = '!plumpy:bundle'


def represent_enum(dumper, enum):
"""Represent an arbitrary enum in yaml.
:param dumper: the dumper to use.
:type dumper: :class:`yaml.dumper.Dumper`
:param bundle: the bundle to represent
:return: the representation
"""
loader = get_object_loader()
return dumper.represent_scalar(_ENUM_TAG, f'{loader.identify_object(enum)}|{enum.value}')


def enum_constructor(loader, serialized):
"""Construct an enum from the serialized representation.
:param loader: the yaml loader.
:type loader: :class:`yaml.loader.Loader`
:param bundle: the enum representation.
:return: the enum.
"""
deserialized = loader.construct_scalar(serialized)
identifier, value = deserialized.split('|')
cls = get_object_loader().load_object(identifier)
enum = cls(value)
return enum


def represent_node(dumper, node):
"""Represent a node in yaml.
Expand Down Expand Up @@ -184,6 +213,7 @@ class AiiDALoader(yaml.Loader):
"""


yaml.add_representer(Enum, represent_enum, Dumper=AiiDADumper)
yaml.add_representer(Bundle, represent_bundle, Dumper=AiiDADumper)
yaml.add_representer(AttributeDict, partial(represent_mapping, _ATTRIBUTE_DICT_TAG), Dumper=AiiDADumper)
yaml.add_constructor(_ATTRIBUTE_DICT_TAG, partial(mapping_constructor, AttributeDict), Loader=AiiDALoader)
Expand All @@ -197,6 +227,7 @@ class AiiDALoader(yaml.Loader):
yaml.add_constructor(_NODE_TAG, node_constructor, Loader=AiiDALoader)
yaml.add_constructor(_GROUP_TAG, group_constructor, Loader=AiiDALoader)
yaml.add_constructor(_COMPUTER_TAG, computer_constructor, Loader=AiiDALoader)
yaml.add_constructor(_ENUM_TAG, enum_constructor, Loader=AiiDALoader)


def serialize(data, encoding=None):
Expand Down
148 changes: 0 additions & 148 deletions tests/common/test_serialize.py

This file was deleted.

166 changes: 166 additions & 0 deletions tests/orm/utils/test_serialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
"""Tests for the :mod:`aiida.orm.utils.serialize` module."""
import types

import numpy as np
import pytest

from aiida import orm
from aiida.common.links import LinkType
from aiida.orm.utils import serialize

pytestmark = pytest.mark.usefixtures('clear_database_before_test')


def test_serialize_round_trip():
"""
Test the serialization of a dictionary with Nodes in various data structure
Also make sure that the serialized data is json-serializable
"""
node_a = orm.Data().store()
node_b = orm.Data().store()

data = {'test': 1, 'list': [1, 2, 3, node_a], 'dict': {('Si',): node_b, 'foo': 'bar'}, 'baz': 'aar'}

serialized_data = serialize.serialize(data)
deserialized_data = serialize.deserialize_unsafe(serialized_data)

# For now manual element-for-element comparison until we come up with general
# purpose function that can equate two node instances properly
assert data['test'] == deserialized_data['test']
assert data['baz'] == deserialized_data['baz']
assert data['list'][:3] == deserialized_data['list'][:3]
assert data['list'][3].uuid == deserialized_data['list'][3].uuid
assert data['dict'][('Si',)].uuid == deserialized_data['dict'][('Si',)].uuid


def test_serialize_group():
"""
Test that serialization and deserialization of Groups works.
Also make sure that the serialized data is json-serializable
"""
group_name = 'groupie'
group_a = orm.Group(label=group_name).store()

data = {'group': group_a}

serialized_data = serialize.serialize(data)
deserialized_data = serialize.deserialize_unsafe(serialized_data)

assert data['group'].uuid == deserialized_data['group'].uuid
assert data['group'].label == deserialized_data['group'].label


def test_serialize_node_round_trip():
"""Test you can serialize and deserialize a node"""
node = orm.Data().store()
deserialized = serialize.deserialize_unsafe(serialize.serialize(node))
assert node.uuid == deserialized.uuid


def test_serialize_group_round_trip():
"""Test you can serialize and deserialize a group"""
group = orm.Group(label='test_serialize_group_round_trip').store()
deserialized = serialize.deserialize_unsafe(serialize.serialize(group))

assert group.uuid == deserialized.uuid
assert group.label == deserialized.label


def test_serialize_computer_round_trip(aiida_localhost):
"""Test you can serialize and deserialize a computer"""
deserialized = serialize.deserialize_unsafe(serialize.serialize(aiida_localhost))

# pylint: disable=no-member
assert aiida_localhost.uuid == deserialized.uuid
assert aiida_localhost.label == deserialized.label


def test_serialize_unstored_node():
"""Test that you can't serialize an unstored node"""
node = orm.Data()

with pytest.raises(ValueError):
serialize.serialize(node)


def test_serialize_unstored_group():
"""Test that you can't serialize an unstored group"""
group = orm.Group(label='test_serialize_unstored_group')

with pytest.raises(ValueError):
serialize.serialize(group)


def test_serialize_unstored_computer():
"""Test that you can't serialize an unstored node"""
computer = orm.Computer('test_computer', 'test_host')

with pytest.raises(ValueError):
serialize.serialize(computer)


def test_mixed_attribute_normal_dict():
"""Regression test for #3092.
The yaml mapping constructor in `aiida.orm.utils.serialize` was not properly "deeply" reconstructing nested
mappings, causing a mix of attribute dictionaries and normal dictionaries to lose information in a round-trip.
If a nested `AttributeDict` contained a normal dictionary, the content of the latter would be lost during the
deserialization, despite the information being present in the serialized yaml dump.
"""
from aiida.common.extendeddicts import AttributeDict

# Construct a nested `AttributeDict`, which should make all nested dictionaries `AttributeDicts` recursively
dictionary = {'nested': AttributeDict({'dict': 'string', 'value': 1})}
attribute_dict = AttributeDict(dictionary)

# Now add a normal dictionary in the attribute dictionary
attribute_dict['nested']['normal'] = {'a': 2}

serialized = serialize.serialize(attribute_dict)
deserialized = serialize.deserialize_unsafe(serialized)

assert attribute_dict, deserialized


def test_serialize_numpy():
"""Regression test for #3709
Check that numpy arrays can be serialized.
"""
data = np.array([1, 2, 3])

serialized = serialize.serialize(data)
deserialized = serialize.deserialize_unsafe(serialized)
assert np.all(data == deserialized)


def test_serialize_simplenamespace():
"""Regression test for #3709
Check that `types.SimpleNamespace` can be serialized.
"""
data = types.SimpleNamespace(a=1, b=2.1)

serialized = serialize.serialize(data)
deserialized = serialize.deserialize_unsafe(serialized)
assert data == deserialized


def test_enum():
"""Test serialization and deserialization of an ``Enum``."""
enum = LinkType.RETURN
serialized = serialize.serialize(enum)
assert isinstance(serialized, str)

deserialized = serialize.deserialize_unsafe(serialized)
assert deserialized == enum

0 comments on commit 8acdc24

Please sign in to comment.