Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Filter #1454

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions test/nodes/test_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import itertools

from parameterized import parameterized
from torch.testing._internal.common_utils import TestCase
from torchdata.nodes.adapters import IterableWrapper
from torchdata.nodes.filter import Filter

from .utils import MockSource, run_test_save_load_state, StatefulRangeNode


class TestFilter(TestCase):
    """Tests for the ``Filter`` node: predicate filtering, multi-epoch use and state handling."""

    def test_filter_basic(self) -> None:
        """Only the items accepted by the predicate come out, in source order."""
        # Keep even numbers from a simple range.
        evens = Filter(IterableWrapper(range(10)), lambda x: x % 2 == 0)
        self.assertEqual(list(evens), [0, 2, 4, 6, 8])

        # Same source contents, different predicate: keep numbers greater than 5.
        above_five = Filter(IterableWrapper(range(10)), lambda x: x > 5)
        self.assertEqual(list(above_five), [6, 7, 8, 9])

    def test_filter_with_mock_source(self) -> None:
        """Filtering dict samples from ``MockSource`` works across two epochs."""
        num_samples = 20
        # Keep items whose "step" is divisible by 3.
        node = Filter(MockSource(num_samples=num_samples), lambda x: x["step"] % 3 == 0)
        expected_steps = list(range(0, num_samples, 3))

        # Two passes over the same node, resetting before each one.
        for _ in range(2):
            node.reset()
            results = list(node)
            self.assertEqual(len(results), len(expected_steps))

            for result, expected_step in zip(results, expected_steps):
                self.assertEqual(result["step"], expected_step)
                self.assertEqual(result["test_tensor"].item(), expected_step)
                self.assertEqual(result["test_str"], f"str_{expected_step}")

    def test_filter_empty_result(self) -> None:
        """A predicate that rejects every item yields an empty result."""
        node = Filter(IterableWrapper(range(10)), lambda x: x > 100)
        self.assertEqual(list(node), [])

    @parameterized.expand(itertools.product([0, 3, 7]))
    def test_save_load_state(self, midpoint: int):
        """Run the shared save/load-state helper against a filtered ``StatefulRangeNode``."""
        n = 50
        # Keep items whose "i" entry is divisible by 3.
        node = Filter(StatefulRangeNode(n=n), lambda x: x["i"] % 3 == 0)
        run_test_save_load_state(self, node, midpoint)

    def test_filter_reset_state(self) -> None:
        """Resetting with the node's own state dict resumes where iteration stopped."""
        node = Filter(IterableWrapper(range(10)), lambda x: x % 2 == 0)

        # Consume the first two accepted items.
        for expected in (0, 2):
            self.assertEqual(next(node), expected)

        # Snapshot the state and feed it straight back into reset().
        node.reset(node.state_dict())

        # Iteration continues from the snapshot rather than restarting at 0.
        for expected in (4, 6, 8):
            self.assertEqual(next(node), expected)

        # Nothing left once every accepted item has been consumed.
        with self.assertRaises(StopIteration):
            next(node)
1 change: 1 addition & 0 deletions torchdata/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .adapters import IterableWrapper, MapStyleWrapper, SamplerWrapper
from .base_node import BaseNode, T
from .batch import Batcher, Unbatcher
from .filter import Filter
from .loader import Loader
from .map import Mapper, ParallelMapper
from .pin_memory import PinMemory
Expand Down
35 changes: 35 additions & 0 deletions torchdata/nodes/filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Any, Callable, Dict, TypeVar, Optional
from torchdata.nodes import BaseNode


T = TypeVar("T")


class Filter(BaseNode[T]):
    """Pass through only those items of a source node that satisfy a predicate.

    Items are pulled from ``source_node`` one at a time. Each item for which
    ``filter_fn`` returns a truthy value is returned unchanged; all others are
    discarded.

    Args:
        source_node (BaseNode[T]): The node whose items are filtered.
        filter_fn (Callable[[T], bool]): Predicate called with each item; return
            True to keep the item, False to drop it.
    """

    # Key under which the source node's state is stored in this node's state dict.
    SOURCE_KEY = "source"

    def __init__(self, source_node: BaseNode[T], filter_fn: Callable[[T], bool]):
        super().__init__()
        self.filter_fn = filter_fn
        self.source = source_node

    def reset(self, initial_state: Optional[Dict[str, Any]] = None):
        """Reset this node and its source.

        Args:
            initial_state: State previously produced by ``get_state``. When it is
                None or empty, the source is reset with None; otherwise the
                source is reset with the value stored under ``SOURCE_KEY``
                (None if that key is absent).
        """
        super().reset(initial_state)
        source_state = None
        if initial_state:
            source_state = initial_state.get(self.SOURCE_KEY)
        self.source.reset(source_state)

    def next(self) -> T:
        """Return the next source item accepted by ``filter_fn``.

        Rejected items are skipped. Whatever ``next(self.source)`` raises when
        the source has no more items is not caught here and propagates to the
        caller.
        """
        while True:
            candidate = next(self.source)
            if not self.filter_fn(candidate):
                continue
            return candidate

    def get_state(self) -> Dict[str, Any]:
        """Return a state dict holding the source node's state under ``SOURCE_KEY``."""
        source_state = self.source.state_dict()
        return {self.SOURCE_KEY: source_state}