diff --git a/iterpy/_iter.py b/iterpy/_iter.py index 5f7647f..f4d85b1 100644 --- a/iterpy/_iter.py +++ b/iterpy/_iter.py @@ -114,6 +114,8 @@ def flatten(self) -> "Iter[T]": for i in self._iterator: if isinstance(i, Sequence) and not isinstance(i, str): values.extend(i) + elif isinstance(i, Iter): + values.extend(i.to_list()) else: values.append(i) diff --git a/iterpy/test_iter.py b/iterpy/test_iter.py index 373fb80..88d823f 100644 --- a/iterpy/test_iter.py +++ b/iterpy/test_iter.py @@ -124,7 +124,8 @@ def test_flatten_iterator(self): def test_flatten_iter_iter(self): iterator: Iter[int] = Iter([1, 2]) nested_iter: Iter[Iter[int]] = Iter([iterator]) - unnested_iter: Iter[int] = nested_iter.flatten() # noqa: F841, RUF100 + unnested_iter: Iter[int] = nested_iter.flatten() + assert unnested_iter.to_list() == [1, 2] def test_flatten_str(self): test_input: list[str] = ["abcd"]