Skip to content

Commit

Permalink
refactor(common): make FrozenDict a subclass of dict (#8693)
Browse files Browse the repository at this point in the history
This was motivated by the need to work around pandas not repr'ing `frozendict`
elements properly (seen in #8687), but while poking at that I found:

- The existing code was unnecessarily nested (the actual dict was boxed
in a `MappingProxyType` which was boxed in a `FrozenDict` - we can do
better by just storing the data in the `FrozenDict` itself).
- There was a bug in the hash implementation where the insertion order mattered,
meaning that `hash(frozendict(a=1, b=2)) != hash(frozendict(b=2, a=1))`.
The hash is now computed from a `frozenset` of the items, so it no longer depends on order.

Fixes #8687.

---------

Co-authored-by: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Co-authored-by: Krisztián Szűcs <szucs.krisztian@gmail.com>
  • Loading branch information
3 people authored Mar 21, 2024
1 parent db79aae commit 32b7514
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 34 deletions.
3 changes: 3 additions & 0 deletions ibis/backends/pandas/tests/test_structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ def test_struct_field_literal(value):
result = con.execute(expr)
assert result == 0

expr = struct.cast("struct<fruit: string, weight: float64>")
assert con.execute(expr) == {"fruit": "pear", "weight": 0.0}


def test_struct_field_series(struct_table):
t = struct_table
Expand Down
1 change: 0 additions & 1 deletion ibis/backends/tests/test_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,6 @@ def test_scalar_param_date(backend, alltypes, value):
"risingwave",
"datafusion",
"clickhouse",
"polars",
"sqlite",
"impala",
"oracle",
Expand Down
39 changes: 11 additions & 28 deletions ibis/common/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import collections.abc
from abc import abstractmethod
from itertools import tee
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Generic, TypeVar

from public import public
Expand Down Expand Up @@ -276,42 +275,26 @@ def isdisjoint(self, other: collections.abc.Mapping) -> bool:


@public
class FrozenDict(Mapping[K, V], Hashable):
"""Immutable dictionary with a precomputed hash value."""

__slots__ = ("__view__", "__precomputed_hash__")
__view__: MappingProxyType
class FrozenDict(dict, Mapping[K, V], Hashable):
__slots__ = ("__precomputed_hash__",)
__precomputed_hash__: int

def __init__(self, *args, **kwargs):
dictview = MappingProxyType(dict(*args, **kwargs))
dicthash = hash(tuple(dictview.items()))
object.__setattr__(self, "__view__", dictview)
object.__setattr__(self, "__precomputed_hash__", dicthash)
super().__init__(*args, **kwargs)
hashable = frozenset(self.items())
object.__setattr__(self, "__precomputed_hash__", hash(hashable))

def __str__(self):
return str(self.__view__)
def __hash__(self) -> int:
return self.__precomputed_hash__

def __repr__(self):
return f"{self.__class__.__name__}({dict(self.__view__)!r})"
def __setitem__(self, key: K, value: V) -> None:
raise TypeError("'FrozenDict' object does not support item assignment")

def __setattr__(self, name: str, _: Any) -> None:
raise TypeError(f"Attribute {name!r} cannot be assigned to frozendict")

def __reduce__(self):
return self.__class__, (dict(self.__view__),)

def __iter__(self):
return iter(self.__view__)

def __len__(self):
return len(self.__view__)

def __getitem__(self, key):
return self.__view__[key]

def __hash__(self):
return self.__precomputed_hash__
def __reduce__(self) -> tuple:
return (self.__class__, (dict(self),))


class RewindableIterator(Iterator[V]):
Expand Down
7 changes: 2 additions & 5 deletions ibis/common/tests/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,12 +423,9 @@ def test_frozendict():
with pytest.raises(TypeError, match=msg):
d["d"] = 4

with pytest.raises(TypeError):
d.__view__["a"] = 2
with pytest.raises(TypeError):
d.__view__ = {"a": 2}
assert hash(FrozenDict(a=1, b=2)) == hash(FrozenDict(b=2, a=1))
assert hash(FrozenDict(a=1, b=2)) != hash(d)

assert hash(d)
assert_pickle_roundtrip(d)


Expand Down

0 comments on commit 32b7514

Please sign in to comment.