diff --git a/my/core/common.py b/my/core/common.py index ccebaf2a..85b9386d 100644 --- a/my/core/common.py +++ b/my/core/common.py @@ -6,7 +6,25 @@ import os import sys import types -from typing import Union, Callable, Dict, Iterable, TypeVar, Sequence, List, Optional, Any, cast, Tuple, TYPE_CHECKING, NoReturn +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + NoReturn, + Optional, + Sequence, + TYPE_CHECKING, + Tuple, + TypeVar, + Union, + cast, + get_args, + get_type_hints, + get_origin, +) import warnings from . import warnings as core_warnings @@ -628,6 +646,59 @@ def assert_never(value: NoReturn) -> NoReturn: assert False, f'Unhandled value: {value} ({type(value).__name__})' +def _check_all_hashable(fun): + # TODO ok, take callable? + hints = get_type_hints(fun) + # TODO needs to be defensive like in cachew? + return_type = hints.get('return') + # TODO check if None + origin = get_origin(return_type) # Iterator etc? + (arg,) = get_args(return_type) + # options we wanna handle are simple type on the top level or union + arg_origin = get_origin(arg) + + if sys.version_info[:2] >= (3, 10): + is_uniontype = arg_origin is types.UnionType + else: + is_uniontype = False + + is_union = arg_origin is Union or is_uniontype + if is_union: + to_check = get_args(arg) + else: + to_check = (arg,) + + no_hash = [ + t + for t in to_check + # seems that objects that have not overridden hash have the attribute but it's set to None + if getattr(t, '__hash__', None) is None + ] + assert len(no_hash) == 0, f'Types {no_hash} are not hashable, this will result in significant performance downgrade for unique_everseen' + + +_UET = TypeVar('_UET') +_UEU = TypeVar('_UEU') + + +def unique_everseen( + fun: Callable[[], Iterable[_UET]], + key: Optional[Callable[[_UET], _UEU]] = None, +) -> Iterator[_UET]: + # TODO support normal iterable as well? + import more_itertools + + # NOTE: it has to take original callable, because otherwise we don't have access to generator type annotations + iterable = fun() + + if key is None: + # todo check key return type as well? but it's more likely to be hashable + if os.environ.get('HPI_CHECK_UNIQUE_EVERSEEN') is not None: + _check_all_hashable(fun) + + return more_itertools.unique_everseen(iterable=iterable, key=key) + + ## legacy imports, keeping them here for backwards compatibility from functools import cached_property as cproperty from typing import Literal diff --git a/my/fbmessenger/android.py b/my/fbmessenger/android.py index d14b6539..fa313ea3 100644 --- a/my/fbmessenger/android.py +++ b/my/fbmessenger/android.py @@ -9,9 +9,8 @@ import sqlite3 from typing import Iterator, Sequence, Optional, Dict, Union, List -from more_itertools import unique_everseen - from my.core import get_files, Paths, datetime_aware, Res, assert_never, LazyLogger, make_config +from my.core.common import unique_everseen from my.core.error import echain from my.core.sqlite import sqlite_connection @@ -242,7 +241,7 @@ def messages() -> Iterator[Res[Message]]: senders: Dict[str, Sender] = {} msgs: Dict[str, Message] = {} threads: Dict[str, Thread] = {} - for x in unique_everseen(_entities()): + for x in unique_everseen(_entities): if isinstance(x, Exception): yield x continue diff --git a/my/instagram/android.py b/my/instagram/android.py index eace1c0a..ea5ee35b 100644 --- a/my/instagram/android.py +++ b/my/instagram/android.py @@ -10,8 +10,6 @@ import sqlite3 from typing import Iterator, Sequence, Optional, Dict, Union -from more_itertools import unique_everseen - from my.core import ( get_files, Paths, @@ -22,6 +20,7 @@ Res, assert_never, ) +from my.core.common import unique_everseen from my.core.cachew import mcachew from my.core.error import echain from my.core.sqlite import sqlite_connect_immutable, select @@ -196,7 +195,7 @@ def _entities() -> Iterator[Res[Union[User, _Message]]]: @mcachew(depends_on=inputs) def messages() -> Iterator[Res[Message]]: id2user: Dict[str, User] = {} - for x in unique_everseen(_entities()): + for x in unique_everseen(_entities): if isinstance(x, Exception): yield x continue diff --git a/my/instagram/gdpr.py b/my/instagram/gdpr.py index a42d73a9..233f040c 100644 --- a/my/instagram/gdpr.py +++ b/my/instagram/gdpr.py @@ -7,7 +7,7 @@ from pathlib import Path from typing import Iterator, Sequence, Dict, Union -from more_itertools import bucket, unique_everseen +from more_itertools import bucket from my.core import ( get_files, @@ -17,6 +17,7 @@ assert_never, make_logger, ) +from my.core.common import unique_everseen from my.config import instagram as user_config @@ -196,7 +197,7 @@ def _entitites_from_path(path: Path) -> Iterator[Res[Union[User, _Message]]]: # TODO basically copy pasted from android.py... hmm def messages() -> Iterator[Res[Message]]: id2user: Dict[str, User] = {} - for x in unique_everseen(_entities()): + for x in unique_everseen(_entities): if isinstance(x, Exception): yield x continue diff --git a/my/tinder/android.py b/my/tinder/android.py index 0ba97391..7e5f5352 100644 --- a/my/tinder/android.py +++ b/my/tinder/android.py @@ -11,9 +11,8 @@ import sqlite3 from typing import Sequence, Iterator, Union, Dict, List, Mapping -from more_itertools import unique_everseen - from my.core import Paths, get_files, Res, assert_never, stat, Stats, datetime_aware, make_logger +from my.core.common import unique_everseen from my.core.error import echain from my.core.sqlite import sqlite_connection import my.config @@ -162,7 +161,7 @@ def _parse_msg(row: sqlite3.Row) -> _Message: def entities() -> Iterator[Res[Entity]]: id2person: Dict[str, Person] = {} id2match: Dict[str, Match] = {} - for x in unique_everseen(_entities()): + for x in unique_everseen(_entities): if isinstance(x, Exception): yield x continue diff --git a/my/twitter/talon.py b/my/twitter/talon.py index e43f6004..306a7350 100644 --- a/my/twitter/talon.py +++ b/my/twitter/talon.py @@ -9,9 +9,8 @@ import sqlite3 from typing import Iterator, Sequence, Union -from more_itertools import unique_everseen - from my.core import Paths, Res, datetime_aware, get_files +from my.core.common import unique_everseen from my.core.sqlite import sqlite_connection from .common import TweetId, permalink @@ -133,7 +132,7 @@ def _parse_tweet(row: sqlite3.Row) -> Tweet: def tweets() -> Iterator[Res[Tweet]]: - for x in unique_everseen(_entities()): + for x in unique_everseen(_entities): if isinstance(x, Exception): yield x elif isinstance(x, _IsTweet): @@ -141,7 +140,7 @@ def tweets() -> Iterator[Res[Tweet]]: def likes() -> Iterator[Res[Tweet]]: - for x in unique_everseen(_entities()): + for x in unique_everseen(_entities): if isinstance(x, Exception): yield x elif isinstance(x, _IsFavorire): diff --git a/my/vk/vk_messages_backup.py b/my/vk/vk_messages_backup.py index 089605b8..18373856 100644 --- a/my/vk/vk_messages_backup.py +++ b/my/vk/vk_messages_backup.py @@ -5,12 +5,12 @@ from datetime import datetime from dataclasses import dataclass import json -from typing import Dict, Iterator, NamedTuple +from typing import Dict, Iterator -from more_itertools import unique_everseen import pytz -from my.core import stat, Stats, Json, Res, datetime_aware +from my.core import stat, Stats, Json, Res, datetime_aware, get_files +from my.core.common import unique_everseen from my.config import vk_messages_backup as config @@ -147,7 +147,7 @@ def _messages() -> Iterator[Res[Message]]: def messages() -> Iterator[Res[Message]]: # seems that during backup messages were sometimes duplicated.. - yield from unique_everseen(_messages()) + yield from unique_everseen(_messages) def stats() -> Stats: diff --git a/my/whatsapp/android.py b/my/whatsapp/android.py index b82c3534..295d8318 100644 --- a/my/whatsapp/android.py +++ b/my/whatsapp/android.py @@ -9,9 +9,8 @@ import sqlite3 from typing import Sequence, Iterator, Optional -from more_itertools import unique_everseen - from my.core import get_files, Paths, datetime_aware, Res, make_logger, make_config +from my.core.common import unique_everseen from my.core.error import echain, notnone from my.core.sqlite import sqlite_connection import my.config @@ -202,4 +201,4 @@ def _messages() -> Iterator[Res[Message]]: def messages() -> Iterator[Res[Message]]: - yield from unique_everseen(_messages()) + yield from unique_everseen(_messages)