diff --git a/ibis/common/collections.py b/ibis/common/collections.py index 3a56872afd9b..2793042b011d 100644 --- a/ibis/common/collections.py +++ b/ibis/common/collections.py @@ -1,7 +1,7 @@ from __future__ import annotations from types import MappingProxyType -from typing import Any, Hashable, Iterable, Iterator, Mapping, Set, TypeVar +from typing import Any, Hashable, Mapping, TypeVar from public import public from typing_extensions import Self @@ -208,225 +208,4 @@ def __repr__(self): return f"{self.__class__.__name__}({super().__repr__()})" -class DisjointSet(Mapping[K, Set[K]]): - """Disjoint set data structure. - - Also known as union-find data structure. It is a data structure that keeps - track of a set of elements partitioned into a number of disjoint (non-overlapping) - subsets. It provides near-constant-time operations to add new sets, to merge - existing sets, and to determine whether elements are in the same set. - - Parameters - ---------- - data : - Initial data to add to the disjoint set. - - Examples - -------- - >>> ds = DisjointSet() - >>> ds.add(1) - 1 - >>> ds.add(2) - 2 - >>> ds.add(3) - 3 - >>> ds.union(1, 2) - True - >>> ds.union(2, 3) - True - >>> ds.find(1) - 1 - >>> ds.find(2) - 1 - >>> ds.find(3) - 1 - >>> ds.union(1, 3) - False - """ - - __slots__ = ("_parents", "_classes") - - def __init__(self, data: Iterable[K] | None = None): - self._parents = {} - self._classes = {} - if data is not None: - for id in data: - self.add(id) - - def __contains__(self, id) -> bool: - """Check if the given id is in the disjoint set. - - Parameters - ---------- - id : - The id to check. - - Returns - ------- - ined: - True if the id is in the disjoint set, False otherwise. - """ - return id in self._parents - - def __getitem__(self, id) -> set[K]: - """Get the set of ids that are in the same class as the given id. - - Parameters - ---------- - id : - The id to get the class for. - - Returns - ------- - class: - The set of ids that are in the same class as the given id, including - the given id. - """ - id = self._parents[id] - return self._classes[id] - - def __iter__(self) -> Iterator[K]: - """Iterate over the ids in the disjoint set.""" - return iter(self._parents) - - def __len__(self) -> int: - """Get the number of ids in the disjoint set.""" - return len(self._parents) - - def __eq__(self, other: Self[K]) -> bool: - """Check if the disjoint set is equal to another disjoint set. - - Parameters - ---------- - other : - The other disjoint set to compare to. - - Returns - ------- - equal: - True if the disjoint sets are equal, False otherwise. - """ - if not isinstance(other, DisjointSet): - return NotImplemented - return self._parents == other._parents - - def add(self, id: K) -> K: - """Add a new id to the disjoint set. - - If the id is not in the disjoint set, it will be added to the disjoint set - along with a new class containing only the given id. - - Parameters - ---------- - id : - The id to add to the disjoint set. - - Returns - ------- - id: - The id that was added to the disjoint set. - """ - if id in self._parents: - return self._parents[id] - self._parents[id] = id - self._classes[id] = {id} - return id - - def find(self, id: K) -> K: - """Find the root of the class that the given id is in. - - Also called as the canonicalized id or the representative id. - - Parameters - ---------- - id : - The id to find the canonicalized id for. - - Returns - ------- - id: - The canonicalized id for the given id. - """ - return self._parents[id] - - def union(self, id1, id2) -> bool: - """Merge the classes that the given ids are in. - - If the ids are already in the same class, this will return False. Otherwise - it will merge the classes and return True. - - Parameters - ---------- - id1 : - The first id to merge the classes for. - id2 : - The second id to merge the classes for. - - Returns - ------- - merged: - True if the classes were merged, False otherwise. - """ - # Find the root of each class - id1 = self._parents[id1] - id2 = self._parents[id2] - if id1 == id2: - return False - - # Merge the smaller eclass into the larger one, aka. union-find by size - class1 = self._classes[id1] - class2 = self._classes[id2] - if len(class1) >= len(class2): - id1, id2 = id2, id1 - class1, class2 = class2, class1 - - # Update the parent pointers, this is called path compression but done - # during the union operation to keep the find operation minimal - for id in class1: - self._parents[id] = id2 - - # Do the actual merging and clear the other eclass - class2 |= class1 - class1.clear() - - return True - - def connected(self, id1, id2): - """Check if the given ids are in the same class. - - True if both ids have the same canonicalized id, False otherwise. - - Parameters - ---------- - id1 : - The first id to check. - id2 : - The second id to check. - - Returns - ------- - connected: - True if the ids are connected, False otherwise. - """ - return self._parents[id1] == self._parents[id2] - - def verify(self): - """Verify that the disjoint set is not corrupted. - - Check that each id's canonicalized id's class. In general corruption - should not happen if the public API is used, but this is a sanity check - to make sure that the internal data structures are not corrupted. - - Returns - ------- - verified: - True if the disjoint set is not corrupted, False otherwise. - """ - for id in self._parents: - if id not in self._classes[self._parents[id]]: - raise RuntimeError( - f"DisjointSet is corrupted: {id} is not in its class" - ) - - public(frozendict=FrozenDict, dotdict=DotDict) diff --git a/ibis/common/egraph.py b/ibis/common/egraph.py index 53259d26a8df..8e0e26ade400 100644 --- a/ibis/common/egraph.py +++ b/ibis/common/egraph.py @@ -3,12 +3,237 @@ import collections import itertools import math -from typing import Any +from collections.abc import Iterable, Iterator, Mapping, Set +from typing import Any, Hashable, TypeVar + +from typing_extensions import Self -from ibis.common.collections import DisjointSet from ibis.common.graph import Node from ibis.util import promote_list +K = TypeVar("K", bound=Hashable) + + +class DisjointSet(Mapping[K, Set[K]]): + """Disjoint set data structure. + + Also known as union-find data structure. It is a data structure that keeps + track of a set of elements partitioned into a number of disjoint (non-overlapping) + subsets. It provides near-constant-time operations to add new sets, to merge + existing sets, and to determine whether elements are in the same set. + + Parameters + ---------- + data : + Initial data to add to the disjoint set. + + Examples + -------- + >>> ds = DisjointSet() + >>> ds.add(1) + 1 + >>> ds.add(2) + 2 + >>> ds.add(3) + 3 + >>> ds.union(1, 2) + True + >>> ds.union(2, 3) + True + >>> ds.find(1) + 1 + >>> ds.find(2) + 1 + >>> ds.find(3) + 1 + >>> ds.union(1, 3) + False + """ + + __slots__ = ("_parents", "_classes") + + def __init__(self, data: Iterable[K] | None = None): + self._parents = {} + self._classes = {} + if data is not None: + for id in data: + self.add(id) + + def __contains__(self, id) -> bool: + """Check if the given id is in the disjoint set. + + Parameters + ---------- + id : + The id to check. + + Returns + ------- + ined: + True if the id is in the disjoint set, False otherwise. + """ + return id in self._parents + + def __getitem__(self, id) -> set[K]: + """Get the set of ids that are in the same class as the given id. + + Parameters + ---------- + id : + The id to get the class for. + + Returns + ------- + class: + The set of ids that are in the same class as the given id, including + the given id. + """ + id = self._parents[id] + return self._classes[id] + + def __iter__(self) -> Iterator[K]: + """Iterate over the ids in the disjoint set.""" + return iter(self._parents) + + def __len__(self) -> int: + """Get the number of ids in the disjoint set.""" + return len(self._parents) + + def __eq__(self, other: Self[K]) -> bool: + """Check if the disjoint set is equal to another disjoint set. + + Parameters + ---------- + other : + The other disjoint set to compare to. + + Returns + ------- + equal: + True if the disjoint sets are equal, False otherwise. + """ + if not isinstance(other, DisjointSet): + return NotImplemented + return self._parents == other._parents + + def add(self, id: K) -> K: + """Add a new id to the disjoint set. + + If the id is not in the disjoint set, it will be added to the disjoint set + along with a new class containing only the given id. + + Parameters + ---------- + id : + The id to add to the disjoint set. + + Returns + ------- + id: + The id that was added to the disjoint set. + """ + if id in self._parents: + return self._parents[id] + self._parents[id] = id + self._classes[id] = {id} + return id + + def find(self, id: K) -> K: + """Find the root of the class that the given id is in. + + Also called as the canonicalized id or the representative id. + + Parameters + ---------- + id : + The id to find the canonicalized id for. + + Returns + ------- + id: + The canonicalized id for the given id. + """ + return self._parents[id] + + def union(self, id1, id2) -> bool: + """Merge the classes that the given ids are in. + + If the ids are already in the same class, this will return False. Otherwise + it will merge the classes and return True. + + Parameters + ---------- + id1 : + The first id to merge the classes for. + id2 : + The second id to merge the classes for. + + Returns + ------- + merged: + True if the classes were merged, False otherwise. + """ + # Find the root of each class + id1 = self._parents[id1] + id2 = self._parents[id2] + if id1 == id2: + return False + + # Merge the smaller eclass into the larger one, aka. union-find by size + class1 = self._classes[id1] + class2 = self._classes[id2] + if len(class1) >= len(class2): + id1, id2 = id2, id1 + class1, class2 = class2, class1 + + # Update the parent pointers, this is called path compression but done + # during the union operation to keep the find operation minimal + for id in class1: + self._parents[id] = id2 + + # Do the actual merging and clear the other eclass + class2 |= class1 + class1.clear() + + return True + + def connected(self, id1, id2): + """Check if the given ids are in the same class. + + True if both ids have the same canonicalized id, False otherwise. + + Parameters + ---------- + id1 : + The first id to check. + id2 : + The second id to check. + + Returns + ------- + connected: + True if the ids are connected, False otherwise. + """ + return self._parents[id1] == self._parents[id2] + + def verify(self): + """Verify that the disjoint set is not corrupted. + + Check that each id's canonicalized id's class. In general corruption + should not happen if the public API is used, but this is a sanity check + to make sure that the internal data structures are not corrupted. + + Returns + ------- + verified: + True if the disjoint set is not corrupted, False otherwise. + """ + for id in self._parents: + if id not in self._classes[self._parents[id]]: + raise RuntimeError( + f"DisjointSet is corrupted: {id} is not in its class" + ) + class Slotted: """A lightweight alternative to `ibis.common.grounds.Concrete`. diff --git a/ibis/common/tests/test_collections.py b/ibis/common/tests/test_collections.py index c24129481f49..d7b0b389b3f4 100644 --- a/ibis/common/tests/test_collections.py +++ b/ibis/common/tests/test_collections.py @@ -2,7 +2,7 @@ import pytest -from ibis.common.collections import DisjointSet, DotDict, FrozenDict, MapSet +from ibis.common.collections import DotDict, FrozenDict, MapSet from ibis.tests.util import assert_pickle_roundtrip @@ -221,73 +221,3 @@ def test_frozendict(): assert hash(d) assert_pickle_roundtrip(d) - - -def test_disjoint_set(): - ds = DisjointSet() - ds.add(1) - ds.add(2) - ds.add(3) - ds.add(4) - - ds1 = DisjointSet([1, 2, 3, 4]) - assert ds == ds1 - assert ds[1] == {1} - assert ds[2] == {2} - assert ds[3] == {3} - assert ds[4] == {4} - - assert ds.union(1, 2) is True - assert ds[1] == {1, 2} - assert ds[2] == {1, 2} - assert ds.union(2, 3) is True - assert ds[1] == {1, 2, 3} - assert ds[2] == {1, 2, 3} - assert ds[3] == {1, 2, 3} - assert ds.union(1, 3) is False - assert ds[4] == {4} - assert ds != ds1 - assert 1 in ds - assert 2 in ds - assert 5 not in ds - - assert ds.find(1) == 1 - assert ds.find(2) == 1 - assert ds.find(3) == 1 - assert ds.find(4) == 4 - - assert ds.connected(1, 2) is True - assert ds.connected(1, 3) is True - assert ds.connected(1, 4) is False - - # test mapping api get - assert ds.get(1) == {1, 2, 3} - assert ds.get(4) == {4} - assert ds.get(5) is None - assert ds.get(5, 5) == 5 - assert ds.get(5, default=5) == 5 - - # test mapping api keys - assert set(ds.keys()) == {1, 2, 3, 4} - assert set(ds) == {1, 2, 3, 4} - - # test mapping api values - assert tuple(ds.values()) == ({1, 2, 3}, {1, 2, 3}, {1, 2, 3}, {4}) - - # test mapping api items - assert tuple(ds.items()) == ( - (1, {1, 2, 3}), - (2, {1, 2, 3}), - (3, {1, 2, 3}), - (4, {4}), - ) - - # check that the disjoint set doesn't get corrupted by adding an existing element - ds.verify() - ds.add(1) - ds.verify() - - with pytest.raises(RuntimeError, match="DisjointSet is corrupted"): - ds._parents[1] = 1 - ds._classes[1] = {1} - ds.verify() diff --git a/ibis/common/tests/test_egraph.py b/ibis/common/tests/test_egraph.py index 22315b2821b2..11f93f014662 100644 --- a/ibis/common/tests/test_egraph.py +++ b/ibis/common/tests/test_egraph.py @@ -6,13 +6,82 @@ import ibis import ibis.expr.datatypes as dt import ibis.expr.operations as ops -from ibis.common.collections import DisjointSet -from ibis.common.egraph import EGraph, ENode, Pattern, Rewrite, Variable +from ibis.common.egraph import DisjointSet, EGraph, ENode, Pattern, Rewrite, Variable from ibis.common.graph import Graph, Node from ibis.common.grounds import Concrete from ibis.util import promote_tuple +def test_disjoint_set(): + ds = DisjointSet() + ds.add(1) + ds.add(2) + ds.add(3) + ds.add(4) + + ds1 = DisjointSet([1, 2, 3, 4]) + assert ds == ds1 + assert ds[1] == {1} + assert ds[2] == {2} + assert ds[3] == {3} + assert ds[4] == {4} + + assert ds.union(1, 2) is True + assert ds[1] == {1, 2} + assert ds[2] == {1, 2} + assert ds.union(2, 3) is True + assert ds[1] == {1, 2, 3} + assert ds[2] == {1, 2, 3} + assert ds[3] == {1, 2, 3} + assert ds.union(1, 3) is False + assert ds[4] == {4} + assert ds != ds1 + assert 1 in ds + assert 2 in ds + assert 5 not in ds + + assert ds.find(1) == 1 + assert ds.find(2) == 1 + assert ds.find(3) == 1 + assert ds.find(4) == 4 + + assert ds.connected(1, 2) is True + assert ds.connected(1, 3) is True + assert ds.connected(1, 4) is False + + # test mapping api get + assert ds.get(1) == {1, 2, 3} + assert ds.get(4) == {4} + assert ds.get(5) is None + assert ds.get(5, 5) == 5 + assert ds.get(5, default=5) == 5 + + # test mapping api keys + assert set(ds.keys()) == {1, 2, 3, 4} + assert set(ds) == {1, 2, 3, 4} + + # test mapping api values + assert tuple(ds.values()) == ({1, 2, 3}, {1, 2, 3}, {1, 2, 3}, {4}) + + # test mapping api items + assert tuple(ds.items()) == ( + (1, {1, 2, 3}), + (2, {1, 2, 3}), + (3, {1, 2, 3}), + (4, {4}), + ) + + # check that the disjoint set doesn't get corrupted by adding an existing element + ds.verify() + ds.add(1) + ds.verify() + + with pytest.raises(RuntimeError, match="DisjointSet is corrupted"): + ds._parents[1] = 1 + ds._classes[1] = {1} + ds.verify() + + class PatternNamespace: def __init__(self, module): self.module = module