Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open for contribution on utility nodes like Filter, Shuffler, Header, Cycler? #1452

Open
keunwoochoi opened this issue Feb 25, 2025 · 1 comment

Comments

@keunwoochoi
Copy link

keunwoochoi commented Feb 25, 2025

Hi, do you think these kinds of nodes would be in scope for Torchdata? If so, I'm happy to open a PR to add them, with renaming and tests, of course.

import logging
import random
from collections import deque
from typing import Any, Callable, Deque, Dict, Optional, TypeVar, Optional
from torchdata.nodes import BaseNode

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Generic type variables. T is the item type of the nodes defined below;
# X and U are not referenced by the code visible in this snippet.
X = TypeVar("X")
T = TypeVar("T")
U = TypeVar("U")


class Filter(BaseNode[T]):
    """Pass through only the source items accepted by a predicate.

    Items are pulled from ``source_node`` one at a time; an item is yielded
    when ``filter_fn(item)`` is truthy and silently skipped otherwise.

    Args:
        source_node: Upstream node to draw items from.
        filter_fn: Predicate called on every source item.
    """

    SOURCE_KEY = "source"

    def __init__(self, source_node: BaseNode[T], filter_fn: Callable[[T], bool]):
        super().__init__()
        self.filter_fn = filter_fn
        self.source = source_node

    def reset(self, initial_state: Optional[Dict[str, Any]] = None):
        """Reset this node and forward the saved source state (if any) upstream."""
        super().reset(initial_state)
        # A missing or empty state means the source starts from scratch.
        source_state = None
        if initial_state:
            source_state = initial_state.get(self.SOURCE_KEY)
        self.source.reset(source_state)

    def next(self) -> T:
        """Return the next accepted item.

        ``StopIteration`` raised by the source is not caught, so it ends this
        node's iteration as well.
        """
        candidate = next(self.source)
        while not self.filter_fn(candidate):
            candidate = next(self.source)
        return candidate

    def get_state(self) -> Dict[str, Any]:
        """Return a state dict holding the source's state under ``SOURCE_KEY``."""
        state: Dict[str, Any] = {}
        state[self.SOURCE_KEY] = self.source.state_dict()
        return state


class Shuffler(BaseNode[T]):
    """Node that shuffles items from source node using a bounded buffer.

    Up to ``buffer_size`` items are held in memory. Each call to ``next``
    removes one buffered item chosen uniformly at random and then tops the
    buffer up from the source, so larger buffers give a more thorough shuffle.

    The buffered items are part of the node's state: the source has already
    advanced past them, so a checkpoint that omitted them would lose those
    items on restore. Consequently the items must be serializable by whatever
    mechanism is used to persist the state dict.

    Args:
        source_node: Upstream node to draw items from.
        buffer_size: Maximum number of items held in the shuffle buffer.
        seed: Optional seed for the random generator. When given, the
            generator is re-created from this seed on every fresh reset.

    Raises:
        ValueError: If ``buffer_size`` is smaller than 1.
    """

    SOURCE_KEY = "source"
    RNG_STATE_KEY = "rng_state"
    BUFFER_KEY = "buffer"

    def __init__(self, source_node: BaseNode[T], buffer_size: int, seed: Optional[int] = None):
        super().__init__()
        if buffer_size < 1:
            raise ValueError("Buffer size must be at least 1")
        self.source = source_node
        self.buffer_size = buffer_size
        # A plain list (holding items of type T) gives O(1) indexed access,
        # which the swap-and-pop removal in ``next`` relies on; a deque is
        # O(n) for indices away from its ends.
        self.buffer = []
        self.rng = random.Random(seed)
        self._initial_seed = seed

    def reset(self, initial_state: Optional[Dict[str, Any]] = None):
        """Reset the node, optionally restoring a state from ``get_state``.

        Args:
            initial_state: State dict to resume from, or ``None`` to start
                over. State dicts without a saved buffer are still accepted;
                the buffer then starts empty.
        """
        super().reset(initial_state)
        self.buffer.clear()

        if initial_state is not None:
            self.source.reset(initial_state.get(self.SOURCE_KEY))
            self.rng.setstate(initial_state[self.RNG_STATE_KEY])
            # Restore the items that had been pulled from the source but not
            # yet emitted when the state was captured.
            self.buffer.extend(initial_state.get(self.BUFFER_KEY, ()))
        else:
            self.source.reset()
            if self._initial_seed is not None:
                self.rng = random.Random(self._initial_seed)

    def _fill_buffer(self) -> bool:
        """Top the buffer up from the source.

        Returns:
            True if the buffer holds at least one item afterwards.
        """
        try:
            while len(self.buffer) < self.buffer_size:
                self.buffer.append(next(self.source))
            return True
        except StopIteration:
            return len(self.buffer) > 0

    def next(self) -> T:
        """Return a randomly chosen buffered item.

        Raises:
            StopIteration: When the buffer is empty and the source has no
                more items.
        """
        if not self.buffer and not self._fill_buffer():
            raise StopIteration

        # Swap-and-pop: move the last item into the chosen slot, then drop
        # the tail, removing the chosen item in O(1) without shifting.
        idx = self.rng.randrange(len(self.buffer))
        item = self.buffer[idx]
        self.buffer[idx] = self.buffer[-1]
        self.buffer.pop()

        # Try to refill buffer so it stays full while the source has items.
        self._fill_buffer()
        return item

    def get_state(self) -> Dict[str, Any]:
        """Return the source state, RNG state and a copy of the buffered items."""
        return {
            self.SOURCE_KEY: self.source.state_dict(),
            self.RNG_STATE_KEY: self.rng.getstate(),
            # Copy so later mutation of the live buffer cannot alter a
            # state dict that has already been handed out.
            self.BUFFER_KEY: list(self.buffer),
        }


class Header(BaseNode[T]):
    """Truncate the source node to at most its first ``n`` items.

    Args:
        source_node: Upstream node to draw items from.
        n: Maximum number of items to yield.

    Raises:
        ValueError: If ``n`` is negative.
    """

    SOURCE_KEY = "source"

    def __init__(self, source_node: BaseNode[T], n: int):
        super().__init__()
        if n < 0:
            raise ValueError("n must be non-negative")
        self.n = n
        self.source = source_node
        self._count = 0

    def reset(self, initial_state: Optional[Dict[str, Any]] = None):
        """Reset the source and restore (or zero) the emitted-item counter."""
        super().reset(initial_state)
        source_state = initial_state.get(self.SOURCE_KEY) if initial_state else None
        self.source.reset(source_state)
        # The counter is required whenever a state dict is supplied.
        self._count = 0 if initial_state is None else initial_state["count"]

    def next(self) -> T:
        """Return the next source item until ``n`` items have been emitted."""
        limit_reached = self._count >= self.n
        if not limit_reached:
            value = next(self.source)
            self._count += 1
            return value
        raise StopIteration

    def get_state(self) -> Dict[str, Any]:
        """Return the source state together with the number of items emitted."""
        state: Dict[str, Any] = {self.SOURCE_KEY: self.source.state_dict()}
        state["count"] = self._count
        return state


class Cycler(BaseNode[T]):
    """Repeat the source node: restart it each time it runs out of items.

    The number of restarts performed so far is tracked and saved in the state.

    Args:
        source_node: Upstream node to cycle over.
    """

    SOURCE_KEY = "source"

    def __init__(self, source_node: BaseNode[T]):
        super().__init__()
        self._cycle_count: int = 0
        self.source = source_node

    def reset(self, initial_state: Optional[Dict[str, Any]] = None):
        """Start over, or resume from a state produced by ``get_state``."""
        super().reset(initial_state)
        if initial_state is None:
            self._cycle_count = 0
            self.source.reset(None)
            return
        self._cycle_count = initial_state["cycle_count"]
        self.source.reset(initial_state.get(self.SOURCE_KEY))

    def next(self) -> T:
        """Return the next source item, restarting the source once if needed.

        If the freshly reset source is immediately exhausted as well, its
        ``StopIteration`` propagates to the caller.
        """
        try:
            value = next(self.source)
        except StopIteration:
            self._cycle_count += 1
            self.source.reset(None)
            value = next(self.source)
        return value

    def get_state(self) -> Dict[str, Any]:
        """Return the source state together with the restart counter."""
        state: Dict[str, Any] = {self.SOURCE_KEY: self.source.state_dict()}
        state["cycle_count"] = self._cycle_count
        return state
@keunwoochoi keunwoochoi changed the title Open for contribution on utility nodes like FilterNode, ShuffleNode, HeaderNode, CycleNode? Open for contribution on utility nodes like Filter, Shuffler, Header, Cycler? Feb 25, 2025
@divyanshk
Copy link
Contributor

Hey @keunwoochoi, thanks for this! These would be a great addition, especially excited for Shuffler.

cc @ramanishsingh, who has been looking at the Filter node.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants