Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added handling of hashing of types with args and typing special forms #684

Merged
merged 11 commits into from
Sep 8, 2023
46 changes: 34 additions & 12 deletions pydra/utils/hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@

# import stat
import struct
import typing as ty
from collections.abc import Mapping
from functools import singledispatch
from hashlib import blake2b
import logging

# from pathlib import Path
from typing import (
Expand All @@ -14,10 +16,11 @@
NewType,
Sequence,
Set,
_SpecialForm,
)
import attrs.exceptions

logger = logging.getLogger("pydra")

try:
from typing import Protocol
except ImportError:
Expand Down Expand Up @@ -88,7 +91,8 @@ def hash_single(obj: object, cache: Cache) -> Hash:
h = blake2b(digest_size=16, person=b"pydra-hash")
for chunk in bytes_repr(obj, cache):
h.update(chunk)
cache[objid] = Hash(h.digest())
hsh = cache[objid] = Hash(h.digest())
logger.debug("Hash of %s object is %s", obj, hsh)
return cache[objid]


Expand All @@ -102,15 +106,17 @@ def __bytes_repr__(self, cache: Cache) -> Iterator[bytes]:
def bytes_repr(obj: object, cache: Cache) -> Iterator[bytes]:
cls = obj.__class__
yield f"{cls.__module__}.{cls.__name__}:{{".encode()
try:
dct = obj.__dict__
except AttributeError as e:
# Attrs creates slots classes by default, so we add this here to handle those
# cases
if attrs.has(type(obj)):
# Drop any attributes that aren't used in comparisons by default
dct = attrs.asdict(obj, recurse=False, filter=lambda a, _: bool(a.eq)) # type: ignore
else:
try:
tclose marked this conversation as resolved.
Show resolved Hide resolved
dct = attrs.asdict(obj, recurse=False) # type: ignore
except attrs.exceptions.NotAnAttrsClassError:
raise TypeError(f"Cannot hash {obj} as it is a slots class") from e
dct = obj.__dict__
except AttributeError as e:
try:
dct = {n: getattr(obj, n) for n in obj.__slots__} # type: ignore
except AttributeError:
raise e
tclose marked this conversation as resolved.
Show resolved Hide resolved
yield from bytes_repr_mapping_contents(dct, cache)
yield b"}"

Expand Down Expand Up @@ -224,10 +230,26 @@ def bytes_repr_dict(obj: dict, cache: Cache) -> Iterator[bytes]:
yield b"}"


@register_serializer(_SpecialForm)
@register_serializer(ty._GenericAlias)
@register_serializer(ty._SpecialForm)
@register_serializer(type)
def bytes_repr_type(klass: type, cache: Cache) -> Iterator[bytes]:
yield f"type:({klass.__module__}.{klass.__name__})".encode()
def type_name(tp):
try:
name = tp.__name__
except AttributeError:
name = tp._name
effigies marked this conversation as resolved.
Show resolved Hide resolved
return name

yield b"type:("
origin = ty.get_origin(klass)
if origin:
yield f"{origin.__module__}.{type_name(origin)}[".encode()
yield from bytes_repr_sequence_contents(ty.get_args(klass), cache)
tclose marked this conversation as resolved.
Show resolved Hide resolved
yield b"]"
else:
yield f"{klass.__module__}.{type_name(klass)}".encode()
yield b")"


@register_serializer(list)
Expand Down
71 changes: 69 additions & 2 deletions pydra/utils/tests/test_hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

import attrs
import pytest

import typing as ty
from fileformats.application import Zip, Json
from ..hash import Cache, UnhashableError, bytes_repr, hash_object, register_serializer


Expand Down Expand Up @@ -134,6 +135,20 @@ def __init__(self, x):
assert re.match(rb".*\.MyClass:{str:1:x=.{16}}", obj_repr)


def test_bytes_repr_slots_obj():
class MyClass:
__slots__ = ("x",)

def __init__(
self,
x,
):
tclose marked this conversation as resolved.
Show resolved Hide resolved
self.x = x

obj_repr = join_bytes_repr(MyClass(1))
assert re.match(rb".*\.MyClass:{str:1:x=.{16}}", obj_repr)


def test_bytes_repr_attrs_slots():
@attrs.define
class MyClass:
Expand All @@ -143,11 +158,63 @@ class MyClass:
assert re.match(rb".*\.MyClass:{str:1:x=.{16}}", obj_repr)


def test_bytes_repr_type():
def test_bytes_repr_attrs_no_slots():
@attrs.define(slots=False)
class MyClass:
x: int

obj_repr = join_bytes_repr(MyClass(1))
assert re.match(rb".*\.MyClass:{str:1:x=.{16}}", obj_repr)


def test_bytes_repr_type1():
obj_repr = join_bytes_repr(Path)
assert obj_repr == b"type:(pathlib.Path)"


def test_bytes_repr_type1a():
obj_repr = join_bytes_repr(Zip[Json])
assert re.match(rb"type:\(fileformats.application.archive.Json__Zip\)", obj_repr)
tclose marked this conversation as resolved.
Show resolved Hide resolved


def test_bytes_repr_type2():
T = ty.TypeVar("T")

class MyClass(ty.Generic[T]):
pass

obj_repr = join_bytes_repr(MyClass[int])
assert re.match(rb"type:\(pydra.utils.tests.test_hash.MyClass\[.{16}\]\)", obj_repr)
tclose marked this conversation as resolved.
Show resolved Hide resolved


def test_bytes_special_form1():
obj_repr = join_bytes_repr(ty.Union[int, float])
assert re.match(rb"type:\(typing.Union\[.{32}\]\)", obj_repr)


def test_bytes_special_form2():
obj_repr = join_bytes_repr(ty.Any)
assert re.match(rb"type:\(typing.Any\)", obj_repr)


def test_bytes_special_form3():
obj_repr = join_bytes_repr(ty.Optional[Path])
assert re.match(rb"type:\(typing.Union\[.{32}\]\)", obj_repr, flags=re.DOTALL)


def test_bytes_special_form4():
obj_repr = join_bytes_repr(ty.Type[Path])
assert re.match(rb"type:\(builtins.type\[.{16}\]\)", obj_repr, flags=re.DOTALL)


def test_bytes_special_form5():
obj_repr = join_bytes_repr(ty.Callable[[Path, int], ty.Tuple[float, str]])
assert re.match(
rb"type:\(collections.abc.Callable\[.{32}\]\)", obj_repr, flags=re.DOTALL
)
assert obj_repr != join_bytes_repr(ty.Callable[[Path, int], ty.Tuple[float, bytes]])


def test_recursive_object():
a = []
b = [a]
Expand Down
2 changes: 1 addition & 1 deletion pydra/utils/tests/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ...engine.specs import File, LazyOutField
from ..typing import TypeParser
from pydra import Workflow
from fileformats.serialization import Json
from fileformats.application import Json
from .utils import (
generic_func_task,
GenericShellTask,
Expand Down