diff --git a/functionalpy/test_sequence.py b/functionalpy/test_sequence.py index 5d81c14..05dae2a 100644 --- a/functionalpy/test_sequence.py +++ b/functionalpy/test_sequence.py @@ -1,10 +1,13 @@ +from collections.abc import Mapping, Sequence +from typing import Literal + from functionalpy._sequence import Seq # TODO: https://github.com/MartinBernstorff/FunctionalPy/issues/40 tests: add hypothesis tests where it makes sense def test_chaining(): sequence = Seq([1, 2]) - result = ( + result: list[int] = ( sequence.filter(lambda x: x % 2 == 0) .map(lambda x: x * 2) .to_list() @@ -14,66 +17,64 @@ def test_chaining(): def test_map(): sequence = Seq([1, 2]) - result = sequence.map(lambda x: x * 2).to_list() + result: list[int] = sequence.map(lambda x: x * 2).to_list() assert result == [2, 4] def test_filter(): - sequence = Seq((1, 2)) - result = sequence.filter(lambda x: x % 2 == 0).to_list() + sequence = Seq([1, 2]) + result: list[int] = sequence.filter( + lambda x: x % 2 == 0 + ).to_list() assert result == [2] def test_reduce(): sequence = Seq([1, 2]) - result = sequence.reduce(lambda x, y: x + y) + result: int = sequence.reduce(lambda x, y: x + y) assert result == 3 -class TestGroupby: - def test_grouped_filter(self): - sequence = Seq([1, 2, 3, 4]) - - def is_even(num: int) -> str: - if num % 2 == 0: - return "even" - return "odd" - - grouped = sequence.groupby(is_even) - assert grouped == { - "odd": [1, 3], - "even": [2, 4], - } +def test_grouped_filter(): + sequence = Seq([1, 2, 3, 4]) - def test_grouped_map(self): - ... + def is_even(num: int) -> str: + if num % 2 == 0: + return "even" + return "odd" - def test_grouped_reduce(self): - ... + grouped: Mapping[str, Sequence[int]] = sequence.groupby(is_even) + assert grouped == { + "odd": [1, 3], + "even": [2, 4], + } def test_flatten(): - test_input = ((1, 2), (3, 4)) + test_input: list[list[int]] = [[1, 2], [3, 4]] sequence = Seq(test_input) - result = sequence.flatten() + result: Seq[int] = sequence.flatten() assert result.to_list() == [1, 2, 3, 4] class TestFlattenTypes: def test_flatten_tuple(self): - test_input = ((1, 2), (3, 4)) + test_input: Sequence[tuple[int, ...]] = ( + (1, 2), + (3, 4), + ) sequence = Seq(test_input) - result = sequence.flatten() + result: Seq[Literal[1, 2, 3, 4]] = sequence.flatten() assert result.to_list() == [1, 2, 3, 4] def test_flatten_list(self): - test_input = [[1, 2], [3, 4]] + test_input: list[list[int]] = [[1, 2], [3, 4]] sequence = Seq(test_input) - result = sequence.flatten() + result: Seq[int] = sequence.flatten() assert result.to_list() == [1, 2, 3, 4] def test_flatten_str(self): - test_input = ["abcd"] + test_input: list[str] = ["abcd"] sequence = Seq(test_input) - result = sequence.flatten() + result: Seq[str] = sequence.flatten() assert result.to_list() == ["a", "b", "c", "d"]