From 68075046c10c521f0a85f6c82e576ea64f6c5df2 Mon Sep 17 00:00:00 2001 From: Tom Nicholas <35968931+TomNicholas@users.noreply.github.com> Date: Fri, 27 Aug 2021 18:49:36 -0400 Subject: [PATCH] Check isomorphism https://github.com/xarray-contrib/datatree/pull/31 * pseudocode ideas for generalizing map_over_subtree * pseudocode for a generalized map_over_subtree (still only one return arg) + a new mapping.py file * pseudocode for mapping but now multiple return values * pseudocode for mapping but with multiple return values * check_isomorphism works and has tests * cleaned up the mapping tests a bit * remove WIP from oter branch * ensure tests pass * map_over_subtree in the public API properly * linting --- datatree/__init__.py | 3 +- datatree/datatree.py | 58 +-------- datatree/mapping.py | 139 ++++++++++++++++++++++ datatree/tests/test_dataset_api.py | 69 +---------- datatree/tests/test_datatree.py | 23 +++- datatree/tests/test_mapping.py | 184 +++++++++++++++++++++++++++++ datatree/treenode.py | 2 +- 7 files changed, 346 insertions(+), 132 deletions(-) create mode 100644 datatree/mapping.py create mode 100644 datatree/tests/test_mapping.py diff --git a/datatree/__init__.py b/datatree/__init__.py index f83edbb..fbe1cba 100644 --- a/datatree/__init__.py +++ b/datatree/__init__.py @@ -1,4 +1,5 @@ # flake8: noqa # Ignoring F401: imported but unused -from .datatree import DataNode, DataTree, map_over_subtree +from .datatree import DataNode, DataTree from .io import open_datatree +from .mapping import map_over_subtree diff --git a/datatree/datatree.py b/datatree/datatree.py index 1bd495d..1828f7c 100644 --- a/datatree/datatree.py +++ b/datatree/datatree.py @@ -1,6 +1,5 @@ from __future__ import annotations -import functools import textwrap from typing import Any, Callable, Dict, Hashable, Iterable, List, Mapping, Union @@ -14,6 +13,7 @@ from xarray.core.ops import NAN_CUM_METHODS, NAN_REDUCE_METHODS, REDUCE_METHODS from xarray.core.variable import Variable +from .mapping import map_over_subtree from .treenode import PathType, TreeNode, _init_single_treenode """ @@ -50,62 +50,6 @@ """ -def map_over_subtree(func): - """ - Decorator which turns a function which acts on (and returns) single Datasets into one which acts on DataTrees. - - Applies a function to every dataset in this subtree, returning a new tree which stores the results. - - The function will be applied to any dataset stored in this node, as well as any dataset stored in any of the - descendant nodes. The returned tree will have the same structure as the original subtree. - - func needs to return a Dataset, DataArray, or None in order to be able to rebuild the subtree after mapping, as each - result will be assigned to its respective node of new tree via `DataTree.__setitem__`. - - Parameters - ---------- - func : callable - Function to apply to datasets with signature: - `func(node.ds, *args, **kwargs) -> Dataset`. - - Function will not be applied to any nodes without datasets. - *args : tuple, optional - Positional arguments passed on to `func`. - **kwargs : Any - Keyword arguments passed on to `func`. - - Returns - ------- - mapped : callable - Wrapped function which returns tree created from results of applying ``func`` to the dataset at each node. - - See also - -------- - DataTree.map_over_subtree - DataTree.map_over_subtree_inplace - """ - - @functools.wraps(func) - def _map_over_subtree(tree, *args, **kwargs): - """Internal function which maps func over every node in tree, returning a tree of the results.""" - - # Recreate and act on root node - out_tree = DataNode(name=tree.name, data=tree.ds) - if out_tree.has_data: - out_tree.ds = func(out_tree.ds, *args, **kwargs) - - # Act on every other node in the tree, and rebuild from results - for node in tree.descendants: - # TODO make a proper relative_path method - relative_path = node.pathstr.replace(tree.pathstr, "") - result = func(node.ds, *args, **kwargs) if node.has_data else None - out_tree[relative_path] = result - - return out_tree - - return _map_over_subtree - - class DatasetPropertiesMixin: """Expose properties of wrapped Dataset""" diff --git a/datatree/mapping.py b/datatree/mapping.py new file mode 100644 index 0000000..b0ff2b2 --- /dev/null +++ b/datatree/mapping.py @@ -0,0 +1,139 @@ +import functools + +from anytree.iterators import LevelOrderIter + +from .treenode import TreeNode + + +class TreeIsomorphismError(ValueError): + """Error raised if two tree objects are not isomorphic to one another when they need to be.""" + + pass + + +def _check_isomorphic(subtree_a, subtree_b, require_names_equal=False): + """ + Check that two trees have the same structure, raising an error if not. + + Does not check the actual data in the nodes, but it does check that if one node does/doesn't have data then its + counterpart in the other tree also does/doesn't have data. + + Also does not check that the root nodes of each tree have the same parent - so this function checks that subtrees + are isomorphic, not the entire tree above (if it exists). + + Can optionally check if respective nodes should have the same name. + + Parameters + ---------- + subtree_a : DataTree + subtree_b : DataTree + require_names_equal : Bool, optional + Whether or not to also check that each node has the same name as its counterpart. Default is False. + + Raises + ------ + TypeError + If either subtree_a or subtree_b are not tree objects. + TreeIsomorphismError + If subtree_a and subtree_b are tree objects, but are not isomorphic to one another, or one contains data at a + location the other does not. Also optionally raised if their structure is isomorphic, but the names of any two + respective nodes are not equal. + """ + # TODO turn this into a public function called assert_isomorphic + + if not isinstance(subtree_a, TreeNode): + raise TypeError( + f"Argument `subtree_a is not a tree, it is of type {type(subtree_a)}" + ) + if not isinstance(subtree_b, TreeNode): + raise TypeError( + f"Argument `subtree_b is not a tree, it is of type {type(subtree_b)}" + ) + + # Walking nodes in "level-order" fashion means walking down from the root breadth-first. + # Checking by walking in this way implicitly assumes that the tree is an ordered tree (which it is so long as + # children are stored in a tuple or list rather than in a set). + for node_a, node_b in zip(LevelOrderIter(subtree_a), LevelOrderIter(subtree_b)): + path_a, path_b = node_a.pathstr, node_b.pathstr + + if require_names_equal: + if node_a.name != node_b.name: + raise TreeIsomorphismError( + f"Trees are not isomorphic because node '{path_a}' in the first tree has " + f"name '{node_a.name}', whereas its counterpart node '{path_b}' in the " + f"second tree has name '{node_b.name}'." + ) + + if node_a.has_data != node_b.has_data: + dat_a = "no " if not node_a.has_data else "" + dat_b = "no " if not node_b.has_data else "" + raise TreeIsomorphismError( + f"Trees are not isomorphic because node '{path_a}' in the first tree has " + f"{dat_a}data, whereas its counterpart node '{path_b}' in the second tree " + f"has {dat_b}data." + ) + + if len(node_a.children) != len(node_b.children): + raise TreeIsomorphismError( + f"Trees are not isomorphic because node '{path_a}' in the first tree has " + f"{len(node_a.children)} children, whereas its counterpart node '{path_b}' in " + f"the second tree has {len(node_b.children)} children." + ) + + +def map_over_subtree(func): + """ + Decorator which turns a function which acts on (and returns) single Datasets into one which acts on DataTrees. + + Applies a function to every dataset in this subtree, returning a new tree which stores the results. + + The function will be applied to any dataset stored in this node, as well as any dataset stored in any of the + descendant nodes. The returned tree will have the same structure as the original subtree. + + func needs to return a Dataset, DataArray, or None in order to be able to rebuild the subtree after mapping, as each + result will be assigned to its respective node of new tree via `DataTree.__setitem__`. + + Parameters + ---------- + func : callable + Function to apply to datasets with signature: + `func(node.ds, *args, **kwargs) -> Dataset`. + + Function will not be applied to any nodes without datasets. + *args : tuple, optional + Positional arguments passed on to `func`. + **kwargs : Any + Keyword arguments passed on to `func`. + + Returns + ------- + mapped : callable + Wrapped function which returns tree created from results of applying ``func`` to the dataset at each node. + + See also + -------- + DataTree.map_over_subtree + DataTree.map_over_subtree_inplace + """ + + @functools.wraps(func) + def _map_over_subtree(tree, *args, **kwargs): + """Internal function which maps func over every node in tree, returning a tree of the results.""" + + # Recreate and act on root node + from .datatree import DataNode + + out_tree = DataNode(name=tree.name, data=tree.ds) + if out_tree.has_data: + out_tree.ds = func(out_tree.ds, *args, **kwargs) + + # Act on every other node in the tree, and rebuild from results + for node in tree.descendants: + # TODO make a proper relative_path method + relative_path = node.pathstr.replace(tree.pathstr, "") + result = func(node.ds, *args, **kwargs) if node.has_data else None + out_tree[relative_path] = result + + return out_tree + + return _map_over_subtree diff --git a/datatree/tests/test_dataset_api.py b/datatree/tests/test_dataset_api.py index 82d8871..e930f49 100644 --- a/datatree/tests/test_dataset_api.py +++ b/datatree/tests/test_dataset_api.py @@ -1,76 +1,9 @@ import numpy as np import pytest import xarray as xr -from test_datatree import create_test_datatree from xarray.testing import assert_equal -from datatree import DataNode, DataTree, map_over_subtree - - -class TestMapOverSubTree: - def test_map_over_subtree(self): - dt = create_test_datatree() - - @map_over_subtree - def times_ten(ds): - return 10.0 * ds - - result_tree = times_ten(dt) - - # TODO write an assert_tree_equal function - for ( - result_node, - original_node, - ) in zip(result_tree.subtree, dt.subtree): - assert isinstance(result_node, DataTree) - - if original_node.has_data: - assert_equal(result_node.ds, original_node.ds * 10.0) - else: - assert not result_node.has_data - - def test_map_over_subtree_with_args_and_kwargs(self): - dt = create_test_datatree() - - @map_over_subtree - def multiply_then_add(ds, times, add=0.0): - return times * ds + add - - result_tree = multiply_then_add(dt, 10.0, add=2.0) - - for ( - result_node, - original_node, - ) in zip(result_tree.subtree, dt.subtree): - assert isinstance(result_node, DataTree) - - if original_node.has_data: - assert_equal(result_node.ds, (original_node.ds * 10.0) + 2.0) - else: - assert not result_node.has_data - - def test_map_over_subtree_method(self): - dt = create_test_datatree() - - def multiply_then_add(ds, times, add=0.0): - return times * ds + add - - result_tree = dt.map_over_subtree(multiply_then_add, 10.0, add=2.0) - - for ( - result_node, - original_node, - ) in zip(result_tree.subtree, dt.subtree): - assert isinstance(result_node, DataTree) - - if original_node.has_data: - assert_equal(result_node.ds, (original_node.ds * 10.0) + 2.0) - else: - assert not result_node.has_data - - @pytest.mark.xfail - def test_map_over_subtree_inplace(self): - raise NotImplementedError +from datatree import DataNode class TestDSProperties: diff --git a/datatree/tests/test_datatree.py b/datatree/tests/test_datatree.py index 3d587e3..f13a7f3 100644 --- a/datatree/tests/test_datatree.py +++ b/datatree/tests/test_datatree.py @@ -7,7 +7,21 @@ from datatree.io import open_datatree -def create_test_datatree(): +def assert_tree_equal(dt_a, dt_b): + assert dt_a.name == dt_b.name + assert dt_a.parent is dt_b.parent + + assert dt_a.ds.equals(dt_b.ds) + for a, b in zip(dt_a.descendants, dt_b.descendants): + assert a.name == b.name + assert a.pathstr == b.pathstr + if a.has_data: + assert a.ds.equals(b.ds) + else: + assert a.ds is b.ds + + +def create_test_datatree(modify=lambda ds: ds): """ Create a test datatree with this structure: @@ -37,12 +51,11 @@ def create_test_datatree(): The structure has deliberately repeated names of tags, variables, and dimensions in order to better check for bugs caused by name conflicts. """ - set1_data = xr.Dataset({"a": 0, "b": 1}) - set2_data = xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])}) - root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])}) + set1_data = modify(xr.Dataset({"a": 0, "b": 1})) + set2_data = modify(xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])})) + root_data = modify(xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])})) # Avoid using __init__ so we can independently test it - # TODO change so it has a DataTree at the bottom root = DataNode(name="root", data=root_data) set1 = DataNode(name="set1", parent=root, data=set1_data) DataNode(name="set1", parent=set1) diff --git a/datatree/tests/test_mapping.py b/datatree/tests/test_mapping.py new file mode 100644 index 0000000..da2ad8b --- /dev/null +++ b/datatree/tests/test_mapping.py @@ -0,0 +1,184 @@ +import pytest +import xarray as xr +from test_datatree import assert_tree_equal, create_test_datatree +from xarray.testing import assert_equal + +from datatree.datatree import DataNode, DataTree +from datatree.mapping import TreeIsomorphismError, _check_isomorphic, map_over_subtree +from datatree.treenode import TreeNode + +empty = xr.Dataset() + + +class TestCheckTreesIsomorphic: + def test_not_a_tree(self): + with pytest.raises(TypeError, match="not a tree"): + _check_isomorphic("s", 1) + + def test_different_widths(self): + dt1 = DataTree(data_objects={"a": empty}) + dt2 = DataTree(data_objects={"a": empty, "b": empty}) + expected_err_str = ( + "'root' in the first tree has 1 children, whereas its counterpart node 'root' in the " + "second tree has 2 children" + ) + with pytest.raises(TreeIsomorphismError, match=expected_err_str): + _check_isomorphic(dt1, dt2) + + def test_different_heights(self): + dt1 = DataTree(data_objects={"a": empty}) + dt2 = DataTree(data_objects={"a": empty, "a/b": empty}) + expected_err_str = ( + "'root/a' in the first tree has 0 children, whereas its counterpart node 'root/a' in the " + "second tree has 1 children" + ) + with pytest.raises(TreeIsomorphismError, match=expected_err_str): + _check_isomorphic(dt1, dt2) + + def test_only_one_has_data(self): + dt1 = DataTree(data_objects={"a": xr.Dataset({"a": 0})}) + dt2 = DataTree(data_objects={"a": None}) + expected_err_str = ( + "'root/a' in the first tree has data, whereas its counterpart node 'root/a' in the " + "second tree has no data" + ) + with pytest.raises(TreeIsomorphismError, match=expected_err_str): + _check_isomorphic(dt1, dt2) + + def test_names_different(self): + dt1 = DataTree(data_objects={"a": xr.Dataset()}) + dt2 = DataTree(data_objects={"b": empty}) + expected_err_str = ( + "'root/a' in the first tree has name 'a', whereas its counterpart node 'root/b' in the " + "second tree has name 'b'" + ) + with pytest.raises(TreeIsomorphismError, match=expected_err_str): + _check_isomorphic(dt1, dt2, require_names_equal=True) + + def test_isomorphic_names_equal(self): + dt1 = DataTree( + data_objects={"a": empty, "b": empty, "b/c": empty, "b/d": empty} + ) + dt2 = DataTree( + data_objects={"a": empty, "b": empty, "b/c": empty, "b/d": empty} + ) + _check_isomorphic(dt1, dt2, require_names_equal=True) + + def test_isomorphic_ordering(self): + dt1 = DataTree( + data_objects={"a": empty, "b": empty, "b/d": empty, "b/c": empty} + ) + dt2 = DataTree( + data_objects={"a": empty, "b": empty, "b/c": empty, "b/d": empty} + ) + _check_isomorphic(dt1, dt2, require_names_equal=False) + + def test_isomorphic_names_not_equal(self): + dt1 = DataTree( + data_objects={"a": empty, "b": empty, "b/c": empty, "b/d": empty} + ) + dt2 = DataTree( + data_objects={"A": empty, "B": empty, "B/C": empty, "B/D": empty} + ) + _check_isomorphic(dt1, dt2) + + def test_not_isomorphic_complex_tree(self): + dt1 = create_test_datatree() + dt2 = create_test_datatree() + dt2.set_node("set1/set2", TreeNode("set3")) + with pytest.raises(TreeIsomorphismError, match="root/set1/set2"): + _check_isomorphic(dt1, dt2) + + +class TestMapOverSubTree: + @pytest.mark.xfail + def test_no_trees_passed(self): + raise NotImplementedError + + @pytest.mark.xfail + def test_not_isomorphic(self): + raise NotImplementedError + + @pytest.mark.xfail + def test_no_trees_returned(self): + raise NotImplementedError + + def test_single_dt_arg(self): + dt = create_test_datatree() + + @map_over_subtree + def times_ten(ds): + return 10.0 * ds + + result_tree = times_ten(dt) + expected = create_test_datatree(modify=lambda ds: 10.0 * ds) + assert_tree_equal(result_tree, expected) + + def test_single_dt_arg_plus_args_and_kwargs(self): + dt = create_test_datatree() + + @map_over_subtree + def multiply_then_add(ds, times, add=0.0): + return times * ds + add + + result_tree = multiply_then_add(dt, 10.0, add=2.0) + expected = create_test_datatree(modify=lambda ds: (10.0 * ds) + 2.0) + assert_tree_equal(result_tree, expected) + + @pytest.mark.xfail + def test_multiple_dt_args(self): + ds = xr.Dataset({"a": ("x", [1, 2, 3])}) + dt = DataNode("root", data=ds) + DataNode("results", data=ds + 0.2, parent=dt) + + @map_over_subtree + def add(ds1, ds2): + return ds1 + ds2 + + expected = DataNode("root", data=ds * 2) + DataNode("results", data=(ds + 0.2) * 2, parent=expected) + + result = add(dt, dt) + + # dt1 = create_test_datatree() + # dt2 = create_test_datatree() + # expected = create_test_datatree(modify=lambda ds: 2 * ds) + + assert_tree_equal(result, expected) + + @pytest.mark.xfail + def test_dt_as_kwarg(self): + raise NotImplementedError + + @pytest.mark.xfail + def test_return_multiple_dts(self): + raise NotImplementedError + + @pytest.mark.xfail + def test_return_no_dts(self): + raise NotImplementedError + + def test_dt_method(self): + dt = create_test_datatree() + + def multiply_then_add(ds, times, add=0.0): + return times * ds + add + + result_tree = dt.map_over_subtree(multiply_then_add, 10.0, add=2.0) + + for ( + result_node, + original_node, + ) in zip(result_tree.subtree, dt.subtree): + assert isinstance(result_node, DataTree) + + if original_node.has_data: + assert_equal(result_node.ds, (original_node.ds * 10.0) + 2.0) + else: + assert not result_node.has_data + + +@pytest.mark.xfail +class TestMapOverSubTreeInplace: + def test_map_over_subtree_inplace(self): + raise NotImplementedError diff --git a/datatree/treenode.py b/datatree/treenode.py index 898ee12..276577e 100644 --- a/datatree/treenode.py +++ b/datatree/treenode.py @@ -84,7 +84,7 @@ def _pre_attach(self, parent: TreeNode) -> None: """ if self.name in list(c.name for c in parent.children): raise KeyError( - f"parent {str(parent)} already has a child named {self.name}" + f"parent {parent.name} already has a child named {self.name}" ) def add_child(self, child: TreeNode) -> None: