From d89f44190091cbe31ce279ec7c438de7bfd63f0c Mon Sep 17 00:00:00 2001 From: Graham Gower Date: Sat, 4 Dec 2021 11:51:31 +0100 Subject: [PATCH] Add toplevel metadata. Closes #275. --- demes/demes.py | 18 ++++++++++++++++-- setup.cfg | 1 + tests/test_demes.py | 34 ++++++++++++++++++++++++++++++++++ 3 files changed, 51 insertions(+), 2 deletions(-) diff --git a/demes/demes.py b/demes/demes.py index c5259f21..e1e429be 100644 --- a/demes/demes.py +++ b/demes/demes.py @@ -1,9 +1,10 @@ -from typing import List, Union, Optional, Dict, MutableMapping, Any, Set, Tuple +import copy +import collections import itertools import math import numbers -import copy import operator +from typing import List, Union, Optional, Dict, MutableMapping, Any, Set, Tuple import warnings import attr @@ -1231,6 +1232,7 @@ class Graph: See also: :meth:`.in_generations`. :ivar list[str] doi: If the graph describes a published demography, the DOI(s) should be be given here as a list. + :ivar dict metadata: A dictionary of arbitrary additional data. :ivar list[Deme] demes: The demes in the demography. :ivar list[AsymmetricMigration] migrations: The continuous migrations for the demographic model. @@ -1257,6 +1259,12 @@ class Graph: iterable_validator=attr.validators.instance_of(list), ), ) + metadata: collections.abc.Mapping = attr.ib( + factory=dict, + validator=attr.validators.instance_of( + collections.abc.Mapping # type: ignore[misc] + ), + ) demes: List[Deme] = attr.ib(factory=list, init=False) migrations: List[AsymmetricMigration] = attr.ib(factory=list, init=False) pulses: List[Pulse] = attr.ib(factory=list, init=False) @@ -1919,6 +1927,7 @@ def fromdict(cls, data: MutableMapping[str, Any]) -> "Graph": "generation_time", "defaults", "doi", + "metadata", "demes", "migrations", "pulses", @@ -1980,6 +1989,7 @@ def fromdict(cls, data: MutableMapping[str, Any]) -> "Graph": time_units=data.pop("time_units"), doi=data.pop("doi", []), generation_time=data.pop("generation_time", None), + metadata=data.pop("metadata", {}), ) for i, deme_data in enumerate( @@ -2315,6 +2325,7 @@ def __init__( generation_time: float = None, doi: list = None, defaults: dict = None, + metadata: dict = None, ): """ :param str description: A human readable description of the demography. @@ -2329,6 +2340,7 @@ def __init__( :param doi: If the graph describes a published demography, the DOI(s) should be be given here as a list. :type doi: list[str] + :param dict metadata: A dictionary of arbitrary additional data. """ self.data: MutableMapping[str, Any] = dict(time_units=time_units) if description is not None: @@ -2339,6 +2351,8 @@ def __init__( self.data["doi"] = doi if defaults is not None: self.data["defaults"] = defaults + if metadata is not None: + self.data["metadata"] = metadata def add_deme( self, diff --git a/setup.cfg b/setup.cfg index bec2e0c0..3303ced5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -51,6 +51,7 @@ per-file-ignores = [mypy] files = demes, tests +warn_unused_ignores = True [mypy-numpy.*] ignore_missing_imports = True diff --git a/tests/test_demes.py b/tests/test_demes.py index 035ae7d6..6d2400f9 100644 --- a/tests/test_demes.py +++ b/tests/test_demes.py @@ -1824,6 +1824,20 @@ def test_bad_doi(self): doi=[""], ) + def test_metadata_empty(self): + graph = Graph(time_units="generations") + assert graph.metadata == {} + + def test_metadata_simple(self): + metadata = dict(one=1, two="string", three=dict(four=[4, 4, 4, 4])) + graph = Graph(time_units="generations", metadata=metadata) + assert graph.metadata == metadata + + @pytest.mark.parametrize("metadata", [None, 1, "string", [1, 2, 3]]) + def test_bad_metadata(self, metadata): + with pytest.raises(TypeError): + Graph(time_units="generations", metadata=metadata) + @pytest.mark.parametrize("graph", tests.example_graphs()) def test_in_generations(self, graph): dg1 = copy.deepcopy(graph) @@ -4233,3 +4247,23 @@ def test_infinities_in_defaults(self): assert g.demes[0].start_time == math.inf assert g.demes[1].start_time == math.inf assert g.migrations[0].start_time == math.inf + + def test_metadata_empty(self): + b = Builder() + b.add_deme("a", epochs=[dict(start_size=1)]) + graph = b.resolve() + assert graph.metadata == {} + + def test_metadata_simple(self): + metadata = dict(one=1, two="string", three=dict(four=[4, 4, 4, 4])) + b = Builder(metadata=metadata) + b.add_deme("a", epochs=[dict(start_size=1)]) + graph = b.resolve() + assert graph.metadata == metadata + + @pytest.mark.parametrize("metadata", [1, "string", [1, 2, 3]]) + def test_bad_metadata(self, metadata): + b = Builder(metadata=metadata) + b.add_deme("a", epochs=[dict(start_size=1)]) + with pytest.raises(TypeError): + b.resolve()