Skip to content

Commit

Permalink
Merge pull request #243 from Oxid15/filter
Browse files Browse the repository at this point in the history
Filter
  • Loading branch information
Oxid15 authored Jun 12, 2024
2 parents 1e73047 + 0f68fe8 commit a7f18e1
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 1 deletion.
1 change: 1 addition & 0 deletions cascade/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
SizedDataset,
Wrapper,
)
from .filter import Filter, IteratorFilter
from .folder_dataset import FolderDataset
from .functions import dataset, modifier
from .modifier import BaseModifier, IteratorModifier, Modifier, Sampler
Expand Down
4 changes: 3 additions & 1 deletion cascade/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import warnings
from abc import abstractmethod
from typing import Any, Generator, Generic, Iterable, Sequence, Sized, TypeVar
from typing import Any, Generator, Generic, Iterable, Iterator, Sequence, Sized, TypeVar

from ..base import PipeMeta, Traceable

Expand Down Expand Up @@ -54,6 +54,8 @@ class IteratorDataset(BaseDataset[T], Iterable[T]):
An abstract class to represent a dataset as
an iterable object
"""
def __iter__(self) -> Iterator[T]:
    """
    Return an iterator over the items of this dataset.

    Simply hands the call over to the next `__iter__` in the
    method resolution order and returns whatever it produces.
    """
    inherited_iterator = super().__iter__()
    return inherited_iterator


class Dataset(BaseDataset[T], Sized):
Expand Down
81 changes: 81 additions & 0 deletions cascade/data/filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""
Copyright 2022-2024 Ilia Moiseev
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Any, Callable, Iterator

from .dataset import Dataset, IteratorDataset
from .modifier import IteratorModifier, Sampler


class Filter(Sampler):
    """
    Filter for datasets that have a length.

    A predicate is evaluated on every item once, at construction time,
    and only the indices of the accepted items are remembered.
    """

    def __init__(self, dataset: Dataset, filter_fn: Callable, *args: Any, **kwargs: Any) -> None:
        """
        Build an index mask by running `filter_fn` over the whole dataset.

        The items themselves are not kept in memory - only the positions
        of those for which `filter_fn` returned a truthy value.

        Parameters
        ----------
        dataset: Dataset
            A dataset to filter. It is read item by item via
            `len(dataset)` and `dataset[i]`.
        filter_fn: Callable
            A function called with every item of the dataset during
            `__init__`. Its result is interpreted as a boolean.
        *args, **kwargs
            Passed on to the parent constructor after `dataset` and
            the number of accepted items.

        Raises
        ------
        RuntimeError
            If reading an item, calling `filter_fn` on it or evaluating
            the truth value of the result raises an exception. The
            original exception is chained as the cause and the message
            contains the offending index.
        """
        # The mask is attached to the instance before the scan starts
        # and is filled in place through a local alias.
        self._mask = []
        kept_indices = self._mask
        for position in range(len(dataset)):
            try:
                if filter_fn(dataset[position]):
                    kept_indices.append(position)
            except Exception as error:
                raise RuntimeError(f"Error when filtering dataset on index: {position}") from error
        super().__init__(dataset, len(kept_indices), *args, **kwargs)

    def __getitem__(self, index: Any):
        """
        Return the item stored at the `index`-th accepted position.

        `index` addresses the mask of accepted items; the value found
        there is then used to index `self._dataset`
        (presumably set by the parent constructor - confirm).
        """
        source_index = self._mask[index]
        return self._dataset[source_index]


class IteratorFilter(IteratorModifier):
    """
    Filter for datasets without length.

    Nothing is evaluated on init: the predicate is applied lazily,
    and only items that pass it are returned from `__next__`.
    """

    def __init__(
        self,
        dataset: IteratorDataset,
        filter_fn: Callable,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Remember the predicate and initialize the parent modifier.

        Parameters
        ----------
        dataset: IteratorDataset
            A dataset to filter
        filter_fn: Callable
            A function applied to every item pulled from the dataset.
            Its result is interpreted as a boolean.
        *args, **kwargs
            Passed on to the parent constructor after `dataset`.
        """
        self._filter_fn = filter_fn
        super().__init__(dataset, *args, **kwargs)

    def __next__(self):
        """
        Return the next item for which the predicate is truthy.

        Items are pulled with `next(self._dataset)` until one passes.
        Whatever that call raises (e.g. `StopIteration` when the source
        is exhausted) is not intercepted here.

        Raises
        ------
        RuntimeError
            If calling the predicate, or evaluating the truth value
            of its result, raises an exception.
        """
        while True:
            # Pulling from the source stays outside of the try block so
            # that its exceptions are not rewrapped.
            candidate = next(self._dataset)
            try:
                accepted = bool(self._filter_fn(candidate))
            except Exception as error:
                raise RuntimeError("Error when filtering iterator") from error
            if accepted:
                return candidate
83 changes: 83 additions & 0 deletions cascade/tests/data/test_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""
Copyright 2022-2024 Ilia Moiseev
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import random
import sys

import pytest

# Absolute path of the parent of the directory that contains this file.
SCRIPT_DIR = os.path.dirname(os.path.abspath(os.path.dirname(__file__)))
# Make the directory one level above SCRIPT_DIR importable. This has to run
# before the `cascade` import that follows.
# NOTE(review): presumably meant to let tests run from a source checkout
# without installing the package - confirm that this is the intended level.
sys.path.append(os.path.dirname(SCRIPT_DIR))

from cascade.data import Filter, IteratorDataset, IteratorFilter, Wrapper


@pytest.mark.parametrize(
    "arr, func, res",
    [
        ([1, 2, 3, 4, 5], lambda x: x < 2, [1]),
        (["a", "aa", "aaa", ""], lambda x: len(x), ["a", "aa", "aaa"]),
    ],
)
def test_filter(arr, func, res):
    """Filter keeps exactly the items for which the function is truthy."""
    filtered = Filter(Wrapper(arr), func)
    collected = [element for element in filtered]
    assert collected == res


def test_empty_filter():
    """Constructing a Filter that rejects every item raises AssertionError."""
    source = Wrapper([0, 1, 2, 3, 4])

    # No element of the source is greater than 4
    with pytest.raises(AssertionError):
        Filter(source, lambda x: x > 4)


def test_runtime_error():
    """An exception inside the filter function surfaces as RuntimeError."""

    def two_is_bad(i):
        # Accept everything, but blow up on the value 2
        if i != 2:
            return True
        raise ValueError("No, 2 is bad!!!")

    source = Wrapper([0, 1, 2, 3])

    with pytest.raises(RuntimeError):
        Filter(source, two_is_bad)


def test_iter():
    """IteratorFilter yields only the items that pass the predicate."""

    class RandomDataStream(IteratorDataset):
        # Ends with probability 0.2 on every step,
        # otherwise produces a random integer from 0 to 255
        def __next__(self):
            if random.random() >= 0.2:
                return random.randint(0, 255)
            raise StopIteration()

    stream = IteratorFilter(RandomDataStream(), lambda x: x < 127)

    # Consume the whole stream first, then check every flag
    below_threshold = [value < 127 for value in stream]
    assert all(below_threshold)


def test_empty_iter():
    """Filtering a stream that ends immediately produces no items."""

    class EmptyStream(IteratorDataset):
        # Exhausted from the very first step
        def __next__(self):
            raise StopIteration()

    stream = IteratorFilter(EmptyStream(), lambda x: True)
    collected = [element for element in stream]

    assert [] == collected

0 comments on commit a7f18e1

Please sign in to comment.