diff --git a/src/main/java/org/mastodon/mamut/treesimilarity/ZhangUnorderedTreeEditDistance.java b/src/main/java/org/mastodon/mamut/treesimilarity/ZhangUnorderedTreeEditDistance.java index 3af1e8943..f93e65776 100644 --- a/src/main/java/org/mastodon/mamut/treesimilarity/ZhangUnorderedTreeEditDistance.java +++ b/src/main/java/org/mastodon/mamut/treesimilarity/ZhangUnorderedTreeEditDistance.java @@ -28,6 +28,7 @@ */ package org.mastodon.mamut.treesimilarity; +import org.mastodon.mamut.treesimilarity.tree.Node; import org.mastodon.mamut.treesimilarity.tree.Tree; import org.mastodon.mamut.treesimilarity.tree.TreeUtils; import org.mastodon.mamut.treesimilarity.util.FlowNetwork; @@ -200,7 +201,7 @@ private NodeMapping< T > treeMapping() private static < T > double distanceTreeToNull( Tree< T > tree2, ToDoubleBiFunction< T, T > costFunction ) { double distance = 0; - for ( Tree< T > subtree : TreeUtils.listOfSubtrees( tree2 ) ) + for ( Tree< T > subtree : TreeUtils.getAllChildren( tree2 ) ) distance += costFunction.applyAsDouble( null, subtree.getAttribute() ); return distance; } @@ -212,8 +213,10 @@ private ZhangUnorderedTreeEditDistance( final Tree< T > tree1, final Tree< T > t root1 = new CachedTree<>( tree1 ); root2 = new CachedTree<>( tree2 ); - subtrees1 = assignIndices( root1 ); - subtrees2 = assignIndices( root2 ); + subtrees1 = TreeUtils.getAllChildren( root1 ); + subtrees2 = TreeUtils.getAllChildren( root2 ); + subtrees1.forEach( cachedTree -> cachedTree.index = subtrees1.indexOf( cachedTree ) ); + subtrees2.forEach( cachedTree -> cachedTree.index = subtrees2.indexOf( cachedTree ) ); costMatrix = new double[ subtrees1.size() ][ subtrees2.size() ]; for ( CachedTree< T > subtree1 : subtrees1 ) @@ -269,21 +272,6 @@ private void computeChangeCosts( CachedTree< T > tree, ToDoubleBiFunction< T, T tree.forestCost = forestCosts; } - private List< CachedTree< T > > assignIndices( CachedTree< T > cachedTree ) - { - List< CachedTree< T > > list = new ArrayList<>(); - assignIndex( list, cachedTree ); - return list; - } - - private void assignIndex( List< CachedTree< T > > list, CachedTree< T > cachedTree ) - { - list.add( cachedTree ); - cachedTree.index = list.size() - 1; - for ( CachedTree< T > child : cachedTree.children ) - assignIndex( list, child ); - } - /** * Calculate the Zhang edit distance between two (labeled) unordered trees. * @@ -648,7 +636,7 @@ private static < T > NodeMapping< T > findBestMapping( final NodeMapping< T > a, return c; } - private static class CachedTree< T > + private static class CachedTree< T > implements Node< CachedTree< T > > { private int index; @@ -679,7 +667,7 @@ private boolean isLeaf() return isLeaf; } - private List< CachedTree< T > > getChildren() + public List< CachedTree< T > > getChildren() { return children; } diff --git a/src/main/java/org/mastodon/mamut/treesimilarity/tree/Node.java b/src/main/java/org/mastodon/mamut/treesimilarity/tree/Node.java new file mode 100644 index 000000000..5a757c350 --- /dev/null +++ b/src/main/java/org/mastodon/mamut/treesimilarity/tree/Node.java @@ -0,0 +1,13 @@ +package org.mastodon.mamut.treesimilarity.tree; + +import java.util.Collection; + +public interface Node< T > +{ + /** + * Get the children of this {@link Tree}. + * + * @return The list of child {@link Tree} objects. + */ + Collection< T > getChildren(); +} diff --git a/src/main/java/org/mastodon/mamut/treesimilarity/tree/Tree.java b/src/main/java/org/mastodon/mamut/treesimilarity/tree/Tree.java index cf3e6e977..ec9ed6829 100644 --- a/src/main/java/org/mastodon/mamut/treesimilarity/tree/Tree.java +++ b/src/main/java/org/mastodon/mamut/treesimilarity/tree/Tree.java @@ -28,21 +28,13 @@ */ package org.mastodon.mamut.treesimilarity.tree; -import java.util.Collection; - /** * A tree data structure. * * @param the type of the attribute of the tree nodes. */ -public interface Tree< T > +public interface Tree< T > extends Node< Tree< T > > { - /** - * Get the children of this {@link Tree}. - * - * @return The list of child {@link Tree} objects. - */ - Collection< Tree< T > > getChildren(); /** * Get the attribute of this {@link Tree}. diff --git a/src/main/java/org/mastodon/mamut/treesimilarity/tree/TreeUtils.java b/src/main/java/org/mastodon/mamut/treesimilarity/tree/TreeUtils.java index dc526e92b..22239f0a8 100644 --- a/src/main/java/org/mastodon/mamut/treesimilarity/tree/TreeUtils.java +++ b/src/main/java/org/mastodon/mamut/treesimilarity/tree/TreeUtils.java @@ -29,6 +29,7 @@ package org.mastodon.mamut.treesimilarity.tree; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.concurrent.atomic.AtomicInteger; @@ -40,22 +41,6 @@ private TreeUtils() // prevent from instantiation } - /** - * Returns a complete list of all descendant subtrees of the given {@code Tree}, including itself. - * - * @return The list of subtrees. - */ - public static < T > List< Tree< T > > listOfSubtrees( final Tree< T > tree ) - { - if ( tree == null ) - return Collections.emptyList(); - List< Tree< T > > list = new ArrayList<>(); - list.add( tree ); - for ( Tree< T > child : tree.getChildren() ) - list.addAll( listOfSubtrees( child ) ); - return list; - } - /** * Gets the number of descendant subtrees of this {@link Tree}, including itself. * @return the number @@ -64,7 +49,7 @@ public static < T > int size( final Tree< T > tree ) { if ( tree == null ) return 0; - return listOfSubtrees( tree ).size(); + return getAllChildren( tree ).size(); } /** @@ -76,10 +61,40 @@ public static < T > int size( final Tree< T > tree ) public static < T > List< T > getAllAttributes( final Tree< T > tree ) { List< T > attributes = new ArrayList<>(); - listOfSubtrees( tree ).forEach( subtree -> attributes.add( subtree.getAttribute() ) ); + getAllChildren( tree ).forEach( subtree -> attributes.add( subtree.getAttribute() ) ); return attributes; } + /** + * Recursively collects all children of the given node, including the node itself. + * + * @param node The root node. + * @param The type of the node. + * @return A list of all children nodes. + */ + public static < T extends Node< T > > List< T > getAllChildren( T node ) + { + if ( node == null ) + return Collections.emptyList(); + List< T > result = new ArrayList<>(); + result.add( node ); + getAllChildrenRecursive( node, result ); + return result; + } + + private static < T extends Node< T > > void getAllChildrenRecursive( T node, List< T > result ) + { + Collection< T > children = node.getChildren(); + if ( children != null ) + { + for ( T child : children ) + { + result.add( child ); + getAllChildrenRecursive( child, result ); + } + } + } + /** * Creates a String of the given tree as a Java code snippet that can be used to create the tree. * @param tree The tree to print. diff --git a/src/test/java/org/mastodon/mamut/treesimilarity/tree/TreeUtilsTest.java b/src/test/java/org/mastodon/mamut/treesimilarity/tree/TreeUtilsTest.java index 589961a21..c257204db 100644 --- a/src/test/java/org/mastodon/mamut/treesimilarity/tree/TreeUtilsTest.java +++ b/src/test/java/org/mastodon/mamut/treesimilarity/tree/TreeUtilsTest.java @@ -40,7 +40,7 @@ class TreeUtilsTest { @Test - void testListOfSubtrees() + void testGetAllChildren() { Tree< Double > emptyTree = SimpleTreeExamples.emptyTree(); @@ -49,8 +49,8 @@ void testListOfSubtrees() subtrees1.add( tree1 ); subtrees1.addAll( tree1.getChildren() ); - assertEquals( Collections.singletonList( emptyTree ), TreeUtils.listOfSubtrees( emptyTree ) ); - assertEquals( subtrees1, TreeUtils.listOfSubtrees( tree1 ) ); + assertEquals( Collections.singletonList( emptyTree ), TreeUtils.getAllChildren( emptyTree ) ); + assertEquals( subtrees1, TreeUtils.getAllChildren( tree1 ) ); } @Test diff --git a/src/test/java/org/mastodon/mamut/treesimilarity/util/NodeMappingTest.java b/src/test/java/org/mastodon/mamut/treesimilarity/util/NodeMappingTest.java index a1e71b35d..604b3d230 100644 --- a/src/test/java/org/mastodon/mamut/treesimilarity/util/NodeMappingTest.java +++ b/src/test/java/org/mastodon/mamut/treesimilarity/util/NodeMappingTest.java @@ -163,10 +163,10 @@ private double computeCosts( Tree< Double > tree1, Tree< Double > tree2, Map< Tr Set< Tree< Double > > keys = mapping.keySet(); Set< Tree< Double > > values = new HashSet<>( mapping.values() ); double costs = 0; - for ( Tree< Double > subtree : TreeUtils.listOfSubtrees( tree1 ) ) + for ( Tree< Double > subtree : TreeUtils.getAllChildren( tree1 ) ) if ( !keys.contains( subtree ) ) costs += DEFAULT_COSTS.applyAsDouble( subtree.getAttribute(), null ); - for ( Tree< Double > subtree : TreeUtils.listOfSubtrees( tree2 ) ) + for ( Tree< Double > subtree : TreeUtils.getAllChildren( tree2 ) ) if ( !values.contains( subtree ) ) costs += DEFAULT_COSTS.applyAsDouble( subtree.getAttribute(), null ); for ( Map.Entry< Tree< Double >, Tree< Double > > entry : mapping.entrySet() )