
Commit

add doc strings, separate out stop criteria into its own file
divyanshk committed Nov 15, 2024
1 parent bb60d96 commit 16f32ee
Showing 5 changed files with 70 additions and 32 deletions.
test/nodes/test_multi_node_weighted_sampler.py (1 addition & 1 deletion)
@@ -13,7 +13,7 @@
 from torchdata.nodes.prefetch import Prefetcher
 
 from torchdata.nodes.samplers.multi_node_weighted_sampler import MultiNodeWeightedSampler
-from torchdata.nodes.samplers.utils import StopCriteria
+from torchdata.nodes.samplers.stop_criteria import StopCriteria
 
 from .utils import DummyIterableDataset, run_test_save_load_state
torchdata/nodes/loader.py (10 additions & 4 deletions)
@@ -47,6 +47,7 @@ def __init__(
         self.root = loader.root
         self._cached_item = None
         self._num_yielded = 0
+        self._cached_state_dict: Optional[Dict[str, Any]] = None
 
     def reset(self, initial_state: Optional[Dict[str, Any]] = None):
         super().reset(initial_state)
@@ -62,6 +63,7 @@ def has_next(self) -> bool:
         if self._cached_item is None:
             try:
                 self._cached_item = next(self)
+                self._cached_state_dict = self.state_dict()
             except StopIteration:
                 pass
         return self._cached_item is not None
@@ -70,13 +72,17 @@ def next(self):
         if self._cached_item is not None:
             item = self._cached_item
             self._cached_item = None
+            self._cached_state_dict = None
             return item
         else:
             item = next(self.root)
         return item
 
     def get_state(self) -> Dict[str, Any]:
-        return {
-            self.ROOT_KEY: self.root.state_dict(),
-            self.NUM_YIELDED_KEY: self._num_yielded,
-        }
+        if self._cached_state_dict is not None:
+            return self._cached_state_dict
+        else:
+            return {
+                self.ROOT_KEY: self.root.state_dict(),
+                self.NUM_YIELDED_KEY: self._num_yielded,
+            }
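
A note on the change above: has_next() answers by pre-fetching, which advances the underlying node, so the new _cached_state_dict pins the snapshot taken at peek time, and get_state() keeps returning it until the cached item is actually consumed by next(). Below is a minimal standalone sketch of this peek-and-cache pattern; the class and its names are illustrative, not the torchdata API.

from typing import Any, Dict, List, Optional


class PeekableCounter:
    """Illustrative only: caches a peeked item together with a state snapshot."""

    def __init__(self, items: List[int]) -> None:
        self._items = items
        self._num_yielded = 0
        self._cached_item: Optional[int] = None
        self._cached_state: Optional[Dict[str, Any]] = None

    def _advance(self) -> int:
        if self._num_yielded >= len(self._items):
            raise StopIteration()
        item = self._items[self._num_yielded]
        self._num_yielded += 1
        return item

    def has_next(self) -> bool:
        if self._cached_item is None:
            try:
                self._cached_item = self._advance()
                self._cached_state = self.state_dict()  # pin the snapshot
            except StopIteration:
                pass
        return self._cached_item is not None

    def next(self) -> int:
        if self._cached_item is not None:
            item = self._cached_item
            self._cached_item = None
            self._cached_state = None  # later state_dict() calls compute afresh
            return item
        return self._advance()

    def state_dict(self) -> Dict[str, Any]:
        # While an item is cached, keep returning the pinned snapshot so that
        # checkpointing between has_next() and next() sees a stable state.
        if self._cached_state is not None:
            return self._cached_state
        return {"num_yielded": self._num_yielded}


it = PeekableCounter([10, 20, 30])
assert it.has_next()
assert it.state_dict() == it.state_dict()  # stable while an item is cached
assert it.next() == 10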
torchdata/nodes/samplers/multi_node_weighted_sampler.py (35 additions & 8 deletions)
@@ -9,13 +9,44 @@
 
 import torch
 from torchdata.nodes.base_node import BaseNode, T
-from torchdata.nodes.samplers.utils import StopCriteria
+from torchdata.nodes.samplers.stop_criteria import StopCriteria
 
 from .utils import _get_rank_seed, get_rank_and_world_size
 
 
 class MultiNodeWeightedSampler(BaseNode[T]):
-    """A node that samples from multiple datasets with weights."""
+    """A node that samples from multiple datasets with weights.
+
+    This node expects to take in a dictionary of source nodes and a dictionary of weights.
+    The keys of the source nodes and weights must be the same. The weights are used to
+    sample from the source nodes via torch.multinomial; please refer to
+    https://pytorch.org/docs/stable/generated/torch.multinomial.html for how weights
+    are used for sampling. `seed` is used to initialize the random number generator.
+
+    The node implements the state using the following keys:
+    - NUM_YIELDED_KEY: The number of items yielded.
+    - WEIGHTED_SAMPLER_STATE_KEY: The state of the weighted sampler.
+    - DATASETS_EXHAUSTED_KEY: A dictionary of booleans indicating whether each source node is exhausted.
+    - DATASET_NODE_STATES_KEY: A dictionary of states for each source node.
+
+    We support multiple stopping criteria:
+    - CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED: Cycle through the source nodes until all datasets
+      are exhausted. This is the default behavior.
+    - FIRST_DATASET_EXHAUSTED: Stop when the first dataset is exhausted.
+    - ALL_DATASETS_EXHAUSTED: Stop when all datasets are exhausted.
+
+    On complete exhaustion of the source nodes, the node will raise StopIteration.
+
+    Parameters:
+        source_nodes (Mapping[str, BaseNode[T]]): A dictionary of source nodes.
+        weights (Dict[str, float]): A dictionary of weights for each source node.
+        stop_criteria (str): The stopping criteria. Default is CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED.
+        rank (int): The rank of the current process. Default is None, in which case the rank
+            will be obtained from the distributed environment.
+        world_size (int): The world size of the distributed environment. Default is None, in
+            which case the world size will be obtained from the distributed environment.
+        seed (int): The seed for the random number generator. Default is 0.
+    """
 
     DATASET_NODE_STATES_KEY = "dataset_node_states"
     NUM_YIELDED_KEY = "num_yielded"
@@ -45,8 +76,6 @@ def __init__(
         self.rank = rank
         self.world_size = world_size
 
-        self.epoch = 0
-
         self._validate()
 
     def _validate(self) -> None:
@@ -126,7 +155,6 @@ def next(self) -> T:
                 if self._datasets_exhausted[key] and self.stop_criteria == StopCriteria.ALL_DATASETS_EXHAUSTED:
                     # Before fetching a new item check if key corresponds to an already
                     # exhausted dataset and StopCriteria is ALL_DATASETS_EXHAUSTED, move to next key
-                    # return next(self) # omit recursive call
                     continue
                 item = next(self.source_nodes[key])
             except StopIteration:
@@ -138,11 +166,10 @@
 
                 # If StopCriteria is ALL_DATASETS_EXHAUSTED, move to next key
                 if self.stop_criteria == StopCriteria.ALL_DATASETS_EXHAUSTED:
-                    # return next(self) # omit recursive call
-                    # key = next(self._weighted_sampler)
                     continue
 
-                # Reset the iterator and try again
+                # If StopCriteria is CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED,
+                # reset the iterator and try again
                 self.source_nodes[key].reset()
                 item = next(self.source_nodes[key])
             break
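For orientation, here is a hedged usage sketch of the sampler documented above. The constructor keywords follow the new docstring; RangeNode is a hypothetical toy source written against the BaseNode surface visible in this diff (reset/next/get_state) and is not part of torchdata.

from typing import Any, Dict, Optional

from torchdata.nodes.base_node import BaseNode
from torchdata.nodes.samplers.multi_node_weighted_sampler import MultiNodeWeightedSampler
from torchdata.nodes.samplers.stop_criteria import StopCriteria


class RangeNode(BaseNode[int]):
    """Hypothetical toy source: yields 0..n-1, then raises StopIteration."""

    def __init__(self, n: int) -> None:
        super().__init__()
        self.n = n
        self.i = 0

    def reset(self, initial_state: Optional[Dict[str, Any]] = None) -> None:
        super().reset(initial_state)
        self.i = initial_state["i"] if initial_state is not None else 0

    def next(self) -> int:
        if self.i >= self.n:
            raise StopIteration()
        self.i += 1
        return self.i - 1

    def get_state(self) -> Dict[str, Any]:
        return {"i": self.i}


node = MultiNodeWeightedSampler(
    source_nodes={"small": RangeNode(2), "large": RangeNode(8)},
    weights={"small": 0.3, "large": 0.7},
    stop_criteria=StopCriteria.FIRST_DATASET_EXHAUSTED,
    seed=0,
)
node.reset()
items = []
while True:
    try:
        items.append(next(node))  # mirrors how this diff drives source nodes
    except StopIteration:
        break
print(items)  # interleaved draws, ending once either source runs dry

With FIRST_DATASET_EXHAUSTED, iteration ends as soon as the 2-item source runs dry; swapping in StopCriteria.ALL_DATASETS_EXHAUSTED would instead yield all ten items exactly once, per the docstring above.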
torchdata/nodes/samplers/stop_criteria.py (24 additions & 0 deletions)
@@ -0,0 +1,24 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+class StopCriteria:
+    """
+    Stopping criteria for the dataset samplers.
+
+    1) CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED: Stop once the last unseen dataset is exhausted.
+       All datasets are seen at least once. In certain cases, some datasets may be
+       seen more than once when there are still non-exhausted datasets.
+
+    2) ALL_DATASETS_EXHAUSTED: Stop once all of the datasets are exhausted. Each
+       dataset is seen exactly once; no wraparound or restart is performed.
+
+    3) FIRST_DATASET_EXHAUSTED: Stop when the first dataset is exhausted.
+    """
+
+    CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED = "CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED"
+    ALL_DATASETS_EXHAUSTED = "ALL_DATASETS_EXHAUSTED"
+    FIRST_DATASET_EXHAUSTED = "FIRST_DATASET_EXHAUSTED"
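
Since the criteria are plain string constants rather than an enum, they compare and serialize like any other string, which keeps them friendly to JSON-style state dicts. A tiny sketch; the helper below is hypothetical, not part of this file:

from torchdata.nodes.samplers.stop_criteria import StopCriteria

assert StopCriteria.FIRST_DATASET_EXHAUSTED == "FIRST_DATASET_EXHAUSTED"


def each_item_seen_exactly_once(criterion: str) -> bool:
    # Hypothetical helper: per the docstring above, only ALL_DATASETS_EXHAUSTED
    # guarantees every dataset is seen exactly once (no wraparound, no early stop).
    return criterion == StopCriteria.ALL_DATASETS_EXHAUSTED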
torchdata/nodes/samplers/utils.py (0 additions & 19 deletions)
@@ -10,25 +10,6 @@
 import torch.distributed as dist
 
 
-class StopCriteria:
-    """
-    Stopping criteria for the dataset samplers.
-
-    1) CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED: Stop once the last unseen dataset is exhausted.
-       All datasets are seen at least once. In certain cases, some datasets may be
-       seen more than once when there are still non-exhausted datasets.
-
-    2) ALL_DATASETS_EXHAUSTED: Stop once all of the datasets are exhausted. Each
-       dataset is seen exactly once; no wraparound or restart is performed.
-
-    3) FIRST_DATASET_EXHAUSTED: Stop when the first dataset is exhausted.
-    """
-
-    CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED = "CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED"
-    ALL_DATASETS_EXHAUSTED = "ALL_DATASETS_EXHAUSTED"
-    FIRST_DATASET_EXHAUSTED = "FIRST_DATASET_EXHAUSTED"
-
-
 def _get_rank_seed(seed: int, generator_rank: torch.Generator, rank: int, world_size: int) -> int:
     generator_rank.manual_seed(seed * world_size + rank)
     return int(torch.randint(0, 2 ** 32 - 1, size=(1,), generator=generator_rank).item())
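As a quick illustration of the seeding scheme retained above, the snippet below inlines a copy of _get_rank_seed so it runs standalone: the same (seed, rank, world_size) always reproduces the same value, and different ranks get independent seeds.

import torch


def _get_rank_seed(seed: int, generator_rank: torch.Generator, rank: int, world_size: int) -> int:
    # Identical to the helper above: derive a per-rank seed from (seed, rank, world_size).
    generator_rank.manual_seed(seed * world_size + rank)
    return int(torch.randint(0, 2 ** 32 - 1, size=(1,), generator=generator_rank).item())


g = torch.Generator()
assert _get_rank_seed(0, g, rank=1, world_size=4) == _get_rank_seed(0, g, rank=1, world_size=4)
# Distinct ranks seed the generator differently, so the derived seeds differ
# with overwhelming probability.
print([_get_rank_seed(0, g, rank=r, world_size=4) for r in range(4)])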
