Skip to content

Commit

Permalink
Check isomorphism xarray-contrib/datatree#31
Browse files Browse the repository at this point in the history
* pseudocode ideas for generalizing map_over_subtree

* pseudocode for a generalized map_over_subtree (still only one return arg) + a new mapping.py file

* pseudocode for mapping but now multiple return values

* pseudocode for mapping but with multiple return values

* check_isomorphism works and has tests

* cleaned up the mapping tests a bit

* remove WIP from oter branch

* ensure tests pass

* map_over_subtree in the public API properly

* linting
  • Loading branch information
TomNicholas authored Aug 27, 2021
1 parent fa69ad7 commit 6807504
Show file tree
Hide file tree
Showing 7 changed files with 346 additions and 132 deletions.
3 changes: 2 additions & 1 deletion datatree/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# flake8: noqa
# Ignoring F401: imported but unused
from .datatree import DataNode, DataTree, map_over_subtree
from .datatree import DataNode, DataTree
from .io import open_datatree
from .mapping import map_over_subtree
58 changes: 1 addition & 57 deletions datatree/datatree.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import functools
import textwrap
from typing import Any, Callable, Dict, Hashable, Iterable, List, Mapping, Union

Expand All @@ -14,6 +13,7 @@
from xarray.core.ops import NAN_CUM_METHODS, NAN_REDUCE_METHODS, REDUCE_METHODS
from xarray.core.variable import Variable

from .mapping import map_over_subtree
from .treenode import PathType, TreeNode, _init_single_treenode

"""
Expand Down Expand Up @@ -50,62 +50,6 @@
"""


def map_over_subtree(func):
"""
Decorator which turns a function which acts on (and returns) single Datasets into one which acts on DataTrees.
Applies a function to every dataset in this subtree, returning a new tree which stores the results.
The function will be applied to any dataset stored in this node, as well as any dataset stored in any of the
descendant nodes. The returned tree will have the same structure as the original subtree.
func needs to return a Dataset, DataArray, or None in order to be able to rebuild the subtree after mapping, as each
result will be assigned to its respective node of new tree via `DataTree.__setitem__`.
Parameters
----------
func : callable
Function to apply to datasets with signature:
`func(node.ds, *args, **kwargs) -> Dataset`.
Function will not be applied to any nodes without datasets.
*args : tuple, optional
Positional arguments passed on to `func`.
**kwargs : Any
Keyword arguments passed on to `func`.
Returns
-------
mapped : callable
Wrapped function which returns tree created from results of applying ``func`` to the dataset at each node.
See also
--------
DataTree.map_over_subtree
DataTree.map_over_subtree_inplace
"""

@functools.wraps(func)
def _map_over_subtree(tree, *args, **kwargs):
"""Internal function which maps func over every node in tree, returning a tree of the results."""

# Recreate and act on root node
out_tree = DataNode(name=tree.name, data=tree.ds)
if out_tree.has_data:
out_tree.ds = func(out_tree.ds, *args, **kwargs)

# Act on every other node in the tree, and rebuild from results
for node in tree.descendants:
# TODO make a proper relative_path method
relative_path = node.pathstr.replace(tree.pathstr, "")
result = func(node.ds, *args, **kwargs) if node.has_data else None
out_tree[relative_path] = result

return out_tree

return _map_over_subtree


class DatasetPropertiesMixin:
"""Expose properties of wrapped Dataset"""

Expand Down
139 changes: 139 additions & 0 deletions datatree/mapping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import functools

from anytree.iterators import LevelOrderIter

from .treenode import TreeNode


class TreeIsomorphismError(ValueError):
"""Error raised if two tree objects are not isomorphic to one another when they need to be."""

pass


def _check_isomorphic(subtree_a, subtree_b, require_names_equal=False):
"""
Check that two trees have the same structure, raising an error if not.
Does not check the actual data in the nodes, but it does check that if one node does/doesn't have data then its
counterpart in the other tree also does/doesn't have data.
Also does not check that the root nodes of each tree have the same parent - so this function checks that subtrees
are isomorphic, not the entire tree above (if it exists).
Can optionally check if respective nodes should have the same name.
Parameters
----------
subtree_a : DataTree
subtree_b : DataTree
require_names_equal : Bool, optional
Whether or not to also check that each node has the same name as its counterpart. Default is False.
Raises
------
TypeError
If either subtree_a or subtree_b are not tree objects.
TreeIsomorphismError
If subtree_a and subtree_b are tree objects, but are not isomorphic to one another, or one contains data at a
location the other does not. Also optionally raised if their structure is isomorphic, but the names of any two
respective nodes are not equal.
"""
# TODO turn this into a public function called assert_isomorphic

if not isinstance(subtree_a, TreeNode):
raise TypeError(
f"Argument `subtree_a is not a tree, it is of type {type(subtree_a)}"
)
if not isinstance(subtree_b, TreeNode):
raise TypeError(
f"Argument `subtree_b is not a tree, it is of type {type(subtree_b)}"
)

# Walking nodes in "level-order" fashion means walking down from the root breadth-first.
# Checking by walking in this way implicitly assumes that the tree is an ordered tree (which it is so long as
# children are stored in a tuple or list rather than in a set).
for node_a, node_b in zip(LevelOrderIter(subtree_a), LevelOrderIter(subtree_b)):
path_a, path_b = node_a.pathstr, node_b.pathstr

if require_names_equal:
if node_a.name != node_b.name:
raise TreeIsomorphismError(
f"Trees are not isomorphic because node '{path_a}' in the first tree has "
f"name '{node_a.name}', whereas its counterpart node '{path_b}' in the "
f"second tree has name '{node_b.name}'."
)

if node_a.has_data != node_b.has_data:
dat_a = "no " if not node_a.has_data else ""
dat_b = "no " if not node_b.has_data else ""
raise TreeIsomorphismError(
f"Trees are not isomorphic because node '{path_a}' in the first tree has "
f"{dat_a}data, whereas its counterpart node '{path_b}' in the second tree "
f"has {dat_b}data."
)

if len(node_a.children) != len(node_b.children):
raise TreeIsomorphismError(
f"Trees are not isomorphic because node '{path_a}' in the first tree has "
f"{len(node_a.children)} children, whereas its counterpart node '{path_b}' in "
f"the second tree has {len(node_b.children)} children."
)


def map_over_subtree(func):
"""
Decorator which turns a function which acts on (and returns) single Datasets into one which acts on DataTrees.
Applies a function to every dataset in this subtree, returning a new tree which stores the results.
The function will be applied to any dataset stored in this node, as well as any dataset stored in any of the
descendant nodes. The returned tree will have the same structure as the original subtree.
func needs to return a Dataset, DataArray, or None in order to be able to rebuild the subtree after mapping, as each
result will be assigned to its respective node of new tree via `DataTree.__setitem__`.
Parameters
----------
func : callable
Function to apply to datasets with signature:
`func(node.ds, *args, **kwargs) -> Dataset`.
Function will not be applied to any nodes without datasets.
*args : tuple, optional
Positional arguments passed on to `func`.
**kwargs : Any
Keyword arguments passed on to `func`.
Returns
-------
mapped : callable
Wrapped function which returns tree created from results of applying ``func`` to the dataset at each node.
See also
--------
DataTree.map_over_subtree
DataTree.map_over_subtree_inplace
"""

@functools.wraps(func)
def _map_over_subtree(tree, *args, **kwargs):
"""Internal function which maps func over every node in tree, returning a tree of the results."""

# Recreate and act on root node
from .datatree import DataNode

out_tree = DataNode(name=tree.name, data=tree.ds)
if out_tree.has_data:
out_tree.ds = func(out_tree.ds, *args, **kwargs)

# Act on every other node in the tree, and rebuild from results
for node in tree.descendants:
# TODO make a proper relative_path method
relative_path = node.pathstr.replace(tree.pathstr, "")
result = func(node.ds, *args, **kwargs) if node.has_data else None
out_tree[relative_path] = result

return out_tree

return _map_over_subtree
69 changes: 1 addition & 68 deletions datatree/tests/test_dataset_api.py
Original file line number Diff line number Diff line change
@@ -1,76 +1,9 @@
import numpy as np
import pytest
import xarray as xr
from test_datatree import create_test_datatree
from xarray.testing import assert_equal

from datatree import DataNode, DataTree, map_over_subtree


class TestMapOverSubTree:
def test_map_over_subtree(self):
dt = create_test_datatree()

@map_over_subtree
def times_ten(ds):
return 10.0 * ds

result_tree = times_ten(dt)

# TODO write an assert_tree_equal function
for (
result_node,
original_node,
) in zip(result_tree.subtree, dt.subtree):
assert isinstance(result_node, DataTree)

if original_node.has_data:
assert_equal(result_node.ds, original_node.ds * 10.0)
else:
assert not result_node.has_data

def test_map_over_subtree_with_args_and_kwargs(self):
dt = create_test_datatree()

@map_over_subtree
def multiply_then_add(ds, times, add=0.0):
return times * ds + add

result_tree = multiply_then_add(dt, 10.0, add=2.0)

for (
result_node,
original_node,
) in zip(result_tree.subtree, dt.subtree):
assert isinstance(result_node, DataTree)

if original_node.has_data:
assert_equal(result_node.ds, (original_node.ds * 10.0) + 2.0)
else:
assert not result_node.has_data

def test_map_over_subtree_method(self):
dt = create_test_datatree()

def multiply_then_add(ds, times, add=0.0):
return times * ds + add

result_tree = dt.map_over_subtree(multiply_then_add, 10.0, add=2.0)

for (
result_node,
original_node,
) in zip(result_tree.subtree, dt.subtree):
assert isinstance(result_node, DataTree)

if original_node.has_data:
assert_equal(result_node.ds, (original_node.ds * 10.0) + 2.0)
else:
assert not result_node.has_data

@pytest.mark.xfail
def test_map_over_subtree_inplace(self):
raise NotImplementedError
from datatree import DataNode


class TestDSProperties:
Expand Down
23 changes: 18 additions & 5 deletions datatree/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,21 @@
from datatree.io import open_datatree


def create_test_datatree():
def assert_tree_equal(dt_a, dt_b):
assert dt_a.name == dt_b.name
assert dt_a.parent is dt_b.parent

assert dt_a.ds.equals(dt_b.ds)
for a, b in zip(dt_a.descendants, dt_b.descendants):
assert a.name == b.name
assert a.pathstr == b.pathstr
if a.has_data:
assert a.ds.equals(b.ds)
else:
assert a.ds is b.ds


def create_test_datatree(modify=lambda ds: ds):
"""
Create a test datatree with this structure:
Expand Down Expand Up @@ -37,12 +51,11 @@ def create_test_datatree():
The structure has deliberately repeated names of tags, variables, and
dimensions in order to better check for bugs caused by name conflicts.
"""
set1_data = xr.Dataset({"a": 0, "b": 1})
set2_data = xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])})
root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])})
set1_data = modify(xr.Dataset({"a": 0, "b": 1}))
set2_data = modify(xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])}))
root_data = modify(xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])}))

# Avoid using __init__ so we can independently test it
# TODO change so it has a DataTree at the bottom
root = DataNode(name="root", data=root_data)
set1 = DataNode(name="set1", parent=root, data=set1_data)
DataNode(name="set1", parent=set1)
Expand Down
Loading

0 comments on commit 6807504

Please sign in to comment.