You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Hi, do you think these kinds of nodes would be in scope for torchdata? If so, I'm happy to open a PR to add them — with renaming and tests, of course.
import logging
import random
from collections import deque
from typing import Any, Callable, Deque, Dict, List, Optional, TypeVar

from torchdata.nodes import BaseNode

logger = logging.getLogger(__name__)
# Module-level type variables. ``T`` is the item type carried by the node
# classes in this module.
# NOTE(review): ``U`` and ``X`` are not referenced by the classes shown here;
# confirm they are unused elsewhere before removing them.
T = TypeVar("T")
U = TypeVar("U")
X = TypeVar("X")
class Filter(BaseNode[T]):
    """Yield only the items of ``source_node`` for which ``filter_fn`` is truthy.

    Apart from its source the node holds no state: a checkpoint is just the
    source's ``state_dict()`` stored under ``SOURCE_KEY``.

    Args:
        source_node: Upstream node supplying candidate items.
        filter_fn: Predicate called once per upstream item; items for which it
            returns a falsy value are skipped.
    """

    SOURCE_KEY = "source"

    def __init__(self, source_node: BaseNode[T], filter_fn: Callable[[T], bool]):
        super().__init__()
        self.source = source_node
        self.filter_fn = filter_fn

    def reset(self, initial_state: Optional[Dict[str, Any]] = None):
        super().reset(initial_state)
        # A missing or empty state restarts the source from scratch.
        source_state = None
        if initial_state:
            source_state = initial_state.get(self.SOURCE_KEY)
        self.source.reset(source_state)

    def next(self) -> T:
        # Keep pulling until a candidate passes the predicate. StopIteration
        # from the exhausted source propagates and ends this node too.
        candidate = next(self.source)
        while not self.filter_fn(candidate):
            candidate = next(self.source)
        return candidate

    def get_state(self) -> Dict[str, Any]:
        source_state = self.source.state_dict()
        return {self.SOURCE_KEY: source_state}
class Shuffler(BaseNode[T]):
    """Shuffle items from ``source_node`` using a bounded in-memory buffer.

    Up to ``buffer_size`` items are held at a time. Each ``next()`` emits a
    randomly chosen buffered item and then tops the buffer back up from the
    source. Once the source is exhausted, the remaining buffered items are
    drained in random order.

    The checkpoint returned by ``get_state()`` contains the source state, the
    RNG state and the items currently buffered. Resuming from it therefore
    continues the same sequence without dropping the already-pulled items.

    Args:
        source_node: Upstream node supplying items.
        buffer_size: Maximum number of items held for shuffling (must be >= 1).
        seed: Optional seed for the node's private ``random.Random``. When
            given, a ``reset()`` without ``initial_state`` re-seeds the RNG.
            When ``None``, the RNG simply continues from its current state.

    Raises:
        ValueError: if ``buffer_size`` is less than 1.
    """

    SOURCE_KEY = "source"
    RNG_STATE_KEY = "rng_state"
    BUFFER_KEY = "buffer"

    def __init__(
        self,
        source_node: BaseNode[T],
        buffer_size: int,
        seed: Optional[int] = None,
    ):
        super().__init__()
        if buffer_size < 1:
            raise ValueError("Buffer size must be at least 1")
        self.source = source_node
        self.buffer_size = buffer_size
        # A list rather than a deque: items are read and overwritten at random
        # positions, which is O(1) for a list but O(n) for a deque.
        self.buffer: List[T] = []
        self.rng = random.Random(seed)
        self._initial_seed = seed

    def reset(self, initial_state: Optional[Dict[str, Any]] = None):
        super().reset(initial_state)
        self.buffer.clear()
        if initial_state is not None:
            self.source.reset(initial_state.get(self.SOURCE_KEY))
            self.rng.setstate(self._as_rng_state(initial_state[self.RNG_STATE_KEY]))
            # Restore items that were pulled from the source but not yet
            # emitted. Checkpoints written before the buffer was saved lack
            # this key and restore with an empty buffer.
            self.buffer.extend(initial_state.get(self.BUFFER_KEY, ()))
        else:
            self.source.reset()
            if self._initial_seed is not None:
                self.rng = random.Random(self._initial_seed)

    @staticmethod
    def _as_rng_state(state: Any) -> Any:
        """Return ``state`` in the tuple form ``random.Random.setstate`` requires.

        ``getstate()`` returns ``(version, internal_state_tuple, gauss_next)``.
        Serializers such as JSON turn the tuples into lists, which
        ``setstate`` rejects with TypeError, so convert them back.
        """
        version, internal_state, gauss_next = state
        return (version, tuple(internal_state), gauss_next)

    def _fill_buffer(self) -> bool:
        """Top the buffer up to ``buffer_size`` items from the source.

        Returns:
            True if the buffer holds at least one item afterwards, False if it
            is empty (the source is exhausted and everything has been emitted).
        """
        try:
            while len(self.buffer) < self.buffer_size:
                self.buffer.append(next(self.source))
        except StopIteration:
            pass
        return bool(self.buffer)

    def next(self) -> T:
        if not self.buffer and not self._fill_buffer():
            raise StopIteration
        # Pick a random slot, overwrite it with the last item and drop the
        # tail. This is O(1) removal; buffer order does not matter because
        # every pick is random.
        idx = self.rng.randrange(len(self.buffer))
        item = self.buffer[idx]
        self.buffer[idx] = self.buffer[-1]
        self.buffer.pop()
        self._fill_buffer()
        return item

    def get_state(self) -> Dict[str, Any]:
        return {
            self.SOURCE_KEY: self.source.state_dict(),
            self.RNG_STATE_KEY: self.rng.getstate(),
            # Items already pulled from the source but not yet emitted.
            # Without them, a resumed run would silently skip up to
            # ``buffer_size`` items. Copied so later mutation cannot alter the
            # snapshot.
            # NOTE(review): the items must be serializable by whatever
            # persists this state.
            self.BUFFER_KEY: list(self.buffer),
        }
class Header(BaseNode[T]):
    """Yield at most the first ``n`` items of ``source_node``.

    After ``n`` items have been emitted, ``next()`` raises ``StopIteration``
    without pulling anything further from the source. The number of items
    emitted so far is checkpointed under ``"count"`` next to the source state.

    Args:
        source_node: Upstream node supplying items.
        n: Maximum number of items to emit.

    Raises:
        ValueError: if ``n`` is negative.
    """

    SOURCE_KEY = "source"

    def __init__(self, source_node: BaseNode[T], n: int):
        super().__init__()
        if n < 0:
            raise ValueError("n must be non-negative")
        self.source = source_node
        self.n = n
        self._count = 0

    def reset(self, initial_state: Optional[Dict[str, Any]] = None):
        super().reset(initial_state)
        source_state = initial_state.get(self.SOURCE_KEY) if initial_state else None
        self.source.reset(source_state)
        # Restore the emitted-item counter, or start a fresh pass at zero.
        self._count = 0 if initial_state is None else initial_state["count"]

    def next(self) -> T:
        if self._count < self.n:
            emitted = next(self.source)
            self._count += 1
            return emitted
        raise StopIteration

    def get_state(self) -> Dict[str, Any]:
        snapshot = {self.SOURCE_KEY: self.source.state_dict()}
        snapshot["count"] = self._count
        return snapshot
class Cycler(BaseNode[T]):
    """Replay ``source_node`` repeatedly by resetting it whenever it runs dry.

    Each time the source raises ``StopIteration``, the node increments
    ``_cycle_count``, resets the source to a fresh state and pulls again. If
    the source yields nothing even right after that reset, the second
    ``StopIteration`` propagates, so an empty source ends iteration instead of
    looping forever.
    """

    SOURCE_KEY = "source"

    def __init__(self, source_node: BaseNode[T]):
        super().__init__()
        self.source = source_node
        self._cycle_count: int = 0

    def reset(self, initial_state: Optional[Dict[str, Any]] = None):
        super().reset(initial_state)
        if initial_state is None:
            self._cycle_count = 0
            self.source.reset(None)
            return
        self._cycle_count = initial_state["cycle_count"]
        self.source.reset(initial_state.get(self.SOURCE_KEY))

    def next(self) -> T:
        try:
            value = next(self.source)
        except StopIteration:
            # Current pass finished: count it and start the source over.
            self._cycle_count += 1
            self.source.reset(None)
            value = next(self.source)
        return value

    def get_state(self) -> Dict[str, Any]:
        state = {self.SOURCE_KEY: self.source.state_dict()}
        state["cycle_count"] = self._cycle_count
        return state
The text was updated successfully, but these errors were encountered:
keunwoochoi
changed the title
Open for contribution on utility nodes like FilterNode, ShuffleNode, HeaderNode, CycleNode?
Open for contribution on utility nodes like Filter, Shuffler, Header, Cycler?
Feb 25, 2025
Hi, do you think these kinds of nodes would be in scope for torchdata? If so, I'm happy to open a PR to add them — with renaming and tests, of course.
The text was updated successfully, but these errors were encountered: