Skip to content

Commit

Permalink
my.core: fix list constructor in always_support_sequence and add some…
Browse files Browse the repository at this point in the history
… tests
  • Loading branch information
karlicoss committed Sep 22, 2024
1 parent 02dabe9 commit 3166109
Showing 1 changed file with 121 additions and 11 deletions.
132 changes: 121 additions & 11 deletions my/core/hpi_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Contains various backwards compatibility/deprecation helpers relevant to HPI itself.
(as opposed to .compat module which implements compatibility between python versions)
"""

import inspect
import os
import re
Expand Down Expand Up @@ -116,32 +117,141 @@ def _get_dal(cfg, module_name: str):
# named to be kinda consistent with more_itertools, e.g. more_itertools.always_iterable
class always_supports_sequence(Iterator[V]):
"""
Helper to make migration from Sequence/List to Iterable/Iterator type backwards compatible
Helper to make migration from Sequence/List to Iterable/Iterator type backwards compatible in runtime
"""

def __init__(self, it: Iterator[V]) -> None:
self.it = it
self._list: Optional[List] = None
self._it = it
self._list: Optional[List[V]] = None
self._lit: Optional[Iterator[V]] = None

def __iter__(self) -> Iterator[V]: # noqa: PYI034
return self.it.__iter__()
if self._list is not None:
self._lit = iter(self._list)
return self

def __next__(self) -> V:
return self.it.__next__()
if self._list is not None:
assert self._lit is not None
delegate = self._lit
else:
delegate = self._it
return next(delegate)

def __getattr__(self, name):
return getattr(self.it, name)
return getattr(self._it, name)

@property
def aslist(self) -> List[V]:
def _aslist(self) -> List[V]:
if self._list is None:
qualname = getattr(self.it, '__qualname__', '<no qualname>') # defensive just in case
qualname = getattr(self._it, '__qualname__', '<no qualname>') # defensive just in case
warnings.medium(f'Using {qualname} as list is deprecated. Migrate to iterative processing or call list() explicitly.')
self._list = list(self.it)
self._list = list(self._it)

# this is necessary for list constructor to work correctly
# since it's __iter__ first, then tries to compute length and then starts iterating...
self._lit = iter(self._list)
return self._list

def __len__(self) -> int:
return len(self.aslist)
return len(self._aslist)

def __getitem__(self, i: int) -> V:
return self.aslist[i]
return self._aslist[i]


def test_always_supports_sequence_list_constructor() -> None:
exhausted = 0

def it() -> Iterator[str]:
nonlocal exhausted
yield from ['a', 'b', 'c']
exhausted += 1

sit = always_supports_sequence(it())

# list constructor is a bit special... it's trying to compute length if it's available to optimize memory allocation
# so, what's happening in this case is
# - sit.__iter__ is called
# - sit.__len__ is called
# - sit.__next__ is called
res = list(sit)
assert res == ['a', 'b', 'c']
assert exhausted == 1

res = list(sit)
assert res == ['a', 'b', 'c']
assert exhausted == 1 # this will iterate over 'cached' list now, so original generator is only exhausted once


def test_always_supports_sequence_indexing() -> None:
exhausted = 0

def it() -> Iterator[str]:
nonlocal exhausted
yield from ['a', 'b', 'c']
exhausted += 1

sit = always_supports_sequence(it())

assert len(sit) == 3
assert exhausted == 1

assert sit[2] == 'c'
assert sit[1] == 'b'
assert sit[0] == 'a'
assert exhausted == 1

# a few tests to make sure list-like operations are working..
assert list(sit) == ['a', 'b', 'c']
assert [x for x in sit] == ['a', 'b', 'c'] # noqa: C416
assert list(sit) == ['a', 'b', 'c']
assert [x for x in sit] == ['a', 'b', 'c'] # noqa: C416
assert exhausted == 1


def test_always_supports_sequence_next() -> None:
exhausted = 0

def it() -> Iterator[str]:
nonlocal exhausted
yield from ['a', 'b', 'c']
exhausted += 1

sit = always_supports_sequence(it())

x = next(sit)
assert x == 'a'
assert exhausted == 0

x = next(sit)
assert x == 'b'
assert exhausted == 0


def test_always_supports_sequence_iter() -> None:
exhausted = 0

def it() -> Iterator[str]:
nonlocal exhausted
yield from ['a', 'b', 'c']
exhausted += 1

sit = always_supports_sequence(it())

for x in sit:
assert x == 'a'
break

x = next(sit)
assert x == 'b'

assert exhausted == 0

x = next(sit)
assert x == 'c'
assert exhausted == 0

for _ in sit:
raise RuntimeError # shouldn't trigger, just exhaust the iterator
assert exhausted == 1

0 comments on commit 3166109

Please sign in to comment.