add fault-tolerance for global random state in map-style datasets #8950

Merged 108 commits on Aug 26, 2021
108 commits
1488879
add LightningFetcher
tchaton Aug 13, 2021
9a5037a
add lightning fetcher
tchaton Aug 13, 2021
6e6e93c
update changelog
tchaton Aug 13, 2021
f4c99a8
typying
tchaton Aug 13, 2021
4412855
add fault tolerant
tchaton Aug 13, 2021
be899aa
Merge branch 'master' into add_lightning_prefetcher_2_n
tchaton Aug 16, 2021
5c54e95
bad merge
tchaton Aug 16, 2021
29c7938
remove prints
tchaton Aug 16, 2021
d1789c8
update
tchaton Aug 16, 2021
3d81454
remove random code
tchaton Aug 16, 2021
64ad33d
fix docstrings and typing
awaelchli Aug 16, 2021
9e1c8a6
resolve todo, rename metadata collate function
awaelchli Aug 16, 2021
91bd840
general cleanup
awaelchli Aug 16, 2021
3ae2a43
fix typo in comment
awaelchli Aug 17, 2021
3ad3afc
update changelog
awaelchli Aug 17, 2021
dd7fc13
remove unused code in apply_to_collection
awaelchli Aug 17, 2021
9579432
Merge branch 'master' into thomas/add_lightning_prefetcher_2_n
awaelchli Aug 17, 2021
4e8697e
random state
awaelchli Aug 17, 2021
e5bb75f
clean up
awaelchli Aug 17, 2021
e65f523
clean out non-global random state (will come in future PR)
awaelchli Aug 17, 2021
909f8ad
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 17, 2021
c23e740
clean out debug statements
awaelchli Aug 17, 2021
dc9525c
fix import
awaelchli Aug 17, 2021
163c486
update data fetcher
awaelchli Aug 17, 2021
6267d01
update cycle iterator code
awaelchli Aug 17, 2021
ce41f56
remove unused function
awaelchli Aug 17, 2021
84aed29
remove unused test
awaelchli Aug 17, 2021
44424c9
remove redundant fix
awaelchli Aug 17, 2021
302e39d
remove state_dict set to None
awaelchli Aug 17, 2021
7af4273
revert type hint Any -> int
awaelchli Aug 17, 2021
7e76efc
rename lastest -> latest
awaelchli Aug 17, 2021
2041809
reword exception message
awaelchli Aug 17, 2021
09787bd
Merge branch 'thomas/add_lightning_prefetcher_2_n' into thomas/fault-…
awaelchli Aug 17, 2021
a3488ab
update type hint
awaelchli Aug 17, 2021
6414ced
remove my own todo
awaelchli Aug 17, 2021
ff6d5ca
remove my own todo
awaelchli Aug 17, 2021
f9b326d
Merge branch 'thomas/add_lightning_prefetcher_2_n' into thomas/fault-…
awaelchli Aug 17, 2021
4254558
fix latest worker id
awaelchli Aug 17, 2021
a1f84b9
remove unused methods (introduced in follow up PR)
awaelchli Aug 17, 2021
75dc31b
Merge branch 'thomas/add_lightning_prefetcher_2_n' into thomas/fault-…
awaelchli Aug 17, 2021
a9fc999
re-introduce methods
awaelchli Aug 17, 2021
08f2503
rename CollectionIteratorState -> MergedIteratorState
awaelchli Aug 17, 2021
c5c7021
update docs
awaelchli Aug 17, 2021
cf6384a
load_state_dict -> from_state_dict
awaelchli Aug 17, 2021
1d6770c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 17, 2021
b336628
add comment
tchaton Aug 17, 2021
0254043
Merge branch 'add_lightning_prefetcher_2_n' of https://github.com/PyT…
tchaton Aug 17, 2021
746e512
Merge branch 'thomas/add_lightning_prefetcher_2_n' into thomas/fault-…
awaelchli Aug 17, 2021
bb8f923
Merge branch 'master' into thomas/fault-tolerant-rng-state
awaelchli Aug 18, 2021
544fd1b
fix merge error
awaelchli Aug 18, 2021
b640c3b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 18, 2021
eed46fd
fix merge error and test
awaelchli Aug 18, 2021
0eae170
revert
awaelchli Aug 18, 2021
766f4b6
revert
awaelchli Aug 18, 2021
b6300c1
reset
awaelchli Aug 18, 2021
3518477
comments and linter
awaelchli Aug 18, 2021
17f6270
unused import
awaelchli Aug 18, 2021
4f73ae7
revert
awaelchli Aug 18, 2021
b787dd9
revert
awaelchli Aug 18, 2021
d8301be
remove redundant "current-iteration" setting
awaelchli Aug 18, 2021
65c1e09
drop random state for sampler for now.
awaelchli Aug 18, 2021
4e7984f
unused imports
awaelchli Aug 18, 2021
bf5c94d
update changelog
awaelchli Aug 18, 2021
45af2c8
clean up test
awaelchli Aug 18, 2021
a72507f
clean up comments in test
awaelchli Aug 18, 2021
1509716
add test docstrings
awaelchli Aug 18, 2021
6799bcf
fix test
awaelchli Aug 18, 2021
9b66c2b
rename test dataset
awaelchli Aug 18, 2021
376c674
move a note
awaelchli Aug 18, 2021
2713917
remove redundant sampler
awaelchli Aug 18, 2021
c2c454c
refactor test
awaelchli Aug 18, 2021
e3e2316
add None check when iterator not available in test
awaelchli Aug 18, 2021
09182d3
update dataloader state loading in fit loop
awaelchli Aug 18, 2021
d68ea88
add comments for dataloader reload
awaelchli Aug 18, 2021
dc4781f
fix test
awaelchli Aug 18, 2021
4795389
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 18, 2021
e390d12
remove redundant deletion
awaelchli Aug 20, 2021
e181c68
Update pytorch_lightning/trainer/supporters.py
awaelchli Aug 20, 2021
b8ebb1b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 20, 2021
03efdbc
update doc string
awaelchli Aug 20, 2021
2c8c59f
change has_completed default to False and update docs
awaelchli Aug 20, 2021
ab1d0f9
Merge branch 'master' into thomas/fault-tolerant-rng-state
awaelchli Aug 23, 2021
5665bd3
fix line too long
awaelchli Aug 23, 2021
c97f04f
Merge branch 'master' into thomas/fault-tolerant-rng-state
awaelchli Aug 24, 2021
c7f2dba
Update pytorch_lightning/trainer/supporters.py
awaelchli Aug 24, 2021
c3c89b8
Update pytorch_lightning/trainer/supporters.py
awaelchli Aug 24, 2021
1a1a6e7
remove duplicate reset
awaelchli Aug 24, 2021
61af0d6
unblock unreachable code
awaelchli Aug 24, 2021
e727d8b
set to any
awaelchli Aug 24, 2021
b38e1ce
Remove unneccesary ANY
carmocca Aug 24, 2021
a8d2d26
add type
awaelchli Aug 25, 2021
b3602fb
Merge branch 'master' into thomas/fault-tolerant-rng-state
awaelchli Aug 25, 2021
bd9f44d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 25, 2021
5000f2e
remove extra iterator call
awaelchli Aug 25, 2021
4bfc7bd
fix global step and epoch counters on failed checkpointing
awaelchli Aug 25, 2021
adb5300
update test
awaelchli Aug 25, 2021
10ee7e6
Merge branch 'bugfix/iter' into thomas/fault-tolerant-rng-state
awaelchli Aug 25, 2021
59f1f4e
update test
awaelchli Aug 25, 2021
37bbcff
Revert "fix global step and epoch counters on failed checkpointing"
awaelchli Aug 25, 2021
f5f1b40
apply changes to validation too
tchaton Aug 26, 2021
499f7a5
Merge branch 'bugfix/iter' into thomas/fault-tolerant-rng-state
awaelchli Aug 26, 2021
b29a1a8
Update pytorch_lightning/trainer/supporters.py
awaelchli Aug 26, 2021
201dc27
Merge branch 'master' into thomas/fault-tolerant-rng-state
awaelchli Aug 26, 2021
fe12aee
drop dataloader_dict
tchaton Aug 26, 2021
ea33e9a
Merge branch 'thomas/fault-tolerant-rng-state' of https://github.com/…
tchaton Aug 26, 2021
d09611c
update
tchaton Aug 26, 2021
277be80
add docstring
tchaton Aug 26, 2021
0342738
add fixme
tchaton Aug 26, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -54,6 +54,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Added `CaptureMapDataset` for state management in map-style datasets ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891))
* Added Fault Tolerant Training to `DataFetcher` ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891))
* Replaced old prefetch iterator with new `DataFetcher` in training loop ([#8953](https://github.com/PyTorchLightning/pytorch-lightning/pull/8953))
* Added partial support for global random state fault-tolerance in map-style datasets ([#8950](https://github.com/PyTorchLightning/pytorch-lightning/pull/8950))

- Checkpoint saving & loading extensibility:
* Added `CheckpointIO` to expose checkpoint IO from training type plugin ([#8743](https://github.com/PyTorchLightning/pytorch-lightning/pull/8743))
18 changes: 17 additions & 1 deletion pytorch_lightning/loops/fit_loop.py
@@ -14,7 +14,7 @@

import logging
from contextlib import suppress
from typing import Optional
from typing import Any, Dict, Optional

from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.epoch import TrainingEpochLoop
@@ -40,6 +40,8 @@ def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] =
self.min_epochs = min_epochs
self.epoch_loop: Optional[TrainingEpochLoop] = None
self.epoch_progress = Progress()
# caches the loaded dataloader state until dataloader objects are available
self._dataloader_state_dict: Dict[str, Any] = {}

@property
def current_epoch(self) -> int:
@@ -175,6 +177,10 @@ def on_advance_start(self) -> None:
if self.current_epoch != 0 and self.trainer._should_reload_dl_epoch:
self.trainer.reset_train_dataloader(model)

if self._dataloader_state_dict:
self.trainer.train_dataloader.load_state_dict(self._dataloader_state_dict)
self._dataloader_state_dict = {}

# TODO: specify the possible exception
with suppress(Exception):
# set seed for distributed sampler (enables shuffling for each epoch)
@@ -234,3 +240,13 @@ def should_accumulate(self) -> bool:

def teardown(self) -> None:
self.epoch_loop.teardown()

def on_save_checkpoint(self) -> Dict:
state_dict = super().on_save_checkpoint()
# FIXME(@tchaton) Should pass has_completed=True when iterator is exhausted ?
state_dict["dataloader_state_dict"] = self.trainer.train_dataloader.state_dict(has_completed=False)
return state_dict

def on_load_checkpoint(self, state_dict: Dict) -> None:
# cache the dataloader state dict until the dataloader objects are available
self._dataloader_state_dict = state_dict.get("dataloader_state_dict", {})
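The fit-loop change above is a small "cache until available" handshake: `on_load_checkpoint` runs before the train dataloader exists, so the loop only stashes the state dict and applies it in `on_advance_start` once `trainer.train_dataloader` has been (re)built. A rough sketch of that pattern, assuming a hypothetical `ToyLoop` standing in for the real loop (not Lightning's implementation):

```python
from typing import Any, Dict, Optional


class ToyLoop:
    """Caches a dataloader state dict loaded from a checkpoint until the dataloader exists."""

    def __init__(self) -> None:
        self.train_dataloader: Optional[Any] = None
        # cache for the state loaded from the checkpoint
        self._dataloader_state_dict: Dict[str, Any] = {}

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # the dataloaders are not built yet, so only remember the state here
        self._dataloader_state_dict = checkpoint.get("dataloader_state_dict", {})

    def on_advance_start(self, train_dataloader: Any) -> None:
        # called once the dataloader object is available; restore and clear the cache
        self.train_dataloader = train_dataloader
        if self._dataloader_state_dict:
            self.train_dataloader.load_state_dict(self._dataloader_state_dict)
            self._dataloader_state_dict = {}
```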
151 changes: 93 additions & 58 deletions pytorch_lightning/trainer/supporters.py
@@ -13,20 +13,23 @@
# limitations under the License.

from collections.abc import Iterable, Iterator, Mapping, Sequence
from dataclasses import dataclass, field
from dataclasses import asdict, dataclass, field
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from torch.utils.data import Dataset
from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, DataLoader
from torch.utils.data.dataloader import _BaseDataLoaderIter, DataLoader
from torch.utils.data.dataset import IterableDataset

from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
from pytorch_lightning.utilities.auto_restart import (
_cycle_to_next_worker_and_reset,
_find_current_worker,
_find_fast_forward_samplers,
CaptureIterableDataset,
CaptureMapDataset,
IteratorState,
MergedIteratorState,
patch_dataloader_iterator,
)
from pytorch_lightning.utilities.data import get_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -167,6 +170,7 @@ def __init__(self, loader: Any, length: Optional[int] = None, state: SharedCycle
self.loader = loader
self._loader_iter = None
self.counter = 0
self.state = state

def __iter__(self) -> Any:
"""
@@ -176,6 +180,7 @@ def __iter__(self) -> Any:
CycleIterator: self
"""
self.counter = 0
self.state.reset()
self._loader_iter = iter(self.loader)
return self

@@ -205,6 +210,12 @@ def __next__(self) -> Any:
raise StopIteration

self._loader_iter = iter(self.loader)
# if fault tolerant is enabled, we need to patch the iterator to collect the states
# before the batch gets returned.
fetcher = getattr(self.loader, "_lightning_fetcher", None)
if fetcher:
patch_dataloader_iterator(self.loader, self._loader_iter, fetcher)

return next(self._loader_iter)

finally:
@@ -302,11 +313,6 @@ def __len__(self) -> int:
return self._calc_num_data(self.datasets, self.mode)


class DataLoaderDict(Dict):
# behaves exactly like a dict, this is used to simplify apply_to_collection.
pass


class CombinedLoader:
"""
Combines different dataloaders and allows sampling in parallel.
@@ -360,80 +366,110 @@ def __init__(self, loaders: Any, mode: str = "min_size"):
self._iterator = None # assigned in __iter__

@staticmethod
def _state_dict_fn(dataloader: DataLoader, iterator: Optional[Iterator], num_batches_processed: int) -> Dict:
# find next worker if multiple workers were used
state = _find_current_worker(iterator)
if isinstance(dataloader.dataset, CaptureIterableDataset):
# the sampler state dict are extracted in `CombinedLoaderIterator`
if iterator is not None and getattr(iterator, "_sampler_state_dict", None) is not None:
state.update(iterator._sampler_state_dict[0])
else:
# fetch directly from fast forward sampler
state.update(dataloader.fast_forward_sampler.state_dict(num_batches_processed))
return DataLoaderDict(state)

def state_dict(self, num_batches_processed: int) -> Dict:
def _state_dict_fn(dataloader: DataLoader, iterator: Optional[Iterator], has_completed: int) -> Dict:
if isinstance(dataloader, CycleIterator):
iterator = dataloader._loader_iter
state = getattr(iterator, "state", None) if has_completed else getattr(iterator, "previous_state", None)
if state:
return asdict(state)
return {}

def state_dict(self, has_completed: bool = False) -> Dict:
"""
The state dict includes all states from wrapped dataloaders and their samplers through the
``CaptureIterableDataset`` and fast-forward samplers.

Args:
num_batches_processed: The number of batches processed so far, needed because the individual dataloaders
may have already prefetched more batches by the time a state dict is requested.
has_completed: whether the current state of data fetching is considered completed or not. If it is, the
current state gets returned, otherwise the previously cached state.
"""
if not _fault_tolerant_training():
return DataLoaderDict()

state_dict_fn = partial(self._state_dict_fn, num_batches_processed=num_batches_processed)
if not _fault_tolerant_training() or self._iterator is None:
return {}

return apply_to_collections(self.loaders, self._iterator.loader_iters, (Iterator, DataLoader), state_dict_fn)
return apply_to_collections(
self.loaders,
self._iterator.loader_iters,
(Iterator, DataLoader),
partial(self._state_dict_fn, has_completed=has_completed),
)

def load_state_dict(self, state_dict):
def load_state_dict(self, state_dict) -> None:
# store the samplers state.
# They will be reloaded once the `CombinedIterator` has been created
# and the workers are created.
self._loaders_iter_state_dict = state_dict

def mock_reset_fn(self, *_, **__):
pass

# mock reset call, so we can rotate the `_worker_queue_idx_cycle` to failed worker
# and get the first batch from it
_MultiProcessingDataLoaderIter._original_reset = _MultiProcessingDataLoaderIter._reset
_MultiProcessingDataLoaderIter._reset = mock_reset_fn

def on_restart(self, iterator: Iterator):
def on_restart(self, iterator: Iterator) -> None:
if not self._loaders_iter_state_dict:
return

# this happens inside the workers if any were specified.
def create_loader_iters(dataloader: DataLoader, state_dict: Dict) -> Iterator:
"""Function used to reload the iterator state before once the workers are created."""

dataloader_to_iter_on = dataloader
if isinstance(dataloader, CycleIterator):
dataloader = dataloader_to_iter_on.loader

dataset = dataloader.dataset

# We reload the states before creating the workers
# The specific type of dataset will then decide if the state should be applied before or after
# spawning the workers
if isinstance(dataset, CaptureMapDataset):
iterator_state = state_dict["state"][0]

if not isinstance(iterator_state, IteratorState):
iterator_state = IteratorState.from_state_dict(iterator_state)

# reload sampler state
ff_sampler = _find_fast_forward_samplers(dataloader)
ff_sampler.load_state_dict(iterator_state.sampler_state)
# reload dataset state
dataset.load_state_dict(
iterator_state.dataset_state,
latest_worker_id=state_dict["latest_worker_id"],
num_workers=iterator_state.num_workers,
)

elif isinstance(dataset, CaptureIterableDataset):
dataset_dict = {
sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items()
}
dataset.load_state_dict(dataset_dict)

def create_loader_iters(dataloader: DataLoader, state_dict: DataLoaderDict):
if isinstance(dataloader.dataset, CaptureIterableDataset):
# provide the `state_dict` to the `CaptureIterableDataset`
# as it is responsible for passing down the state to associated `FastForwardSampler`
dataloader.dataset.load_state_dict(state_dict)
else:
# for `Mapping-based` dataset, the `fast_forward_sampler` was attached
# on the dataloader for simplicity
dataloader.fast_forward_sampler.load_state_dict(state_dict)
raise MisconfigurationException(
"This shouldn't happen. Please, open an issue on PyTorch Lightning Github."
)

# Finally, spawn the workers (if any).
it = iter(dataloader_to_iter_on)

# cycle back the iterator to the failed worker if multiple workers were provided
iterator = _cycle_to_next_worker_and_reset(dataloader, state_dict)
# restore caching state
state = MergedIteratorState.from_state_dict(state_dict)

if isinstance(dataloader.dataset, CaptureIterableDataset):
# remove keys related to iterator
state_dict = {k: v for k, v in state_dict.items() if k not in ("num_worker", "previous_worker")}
# need to re-attach the state dict into the iterator for future collection.
iterator._sampler_state_dict = [state_dict]
return iterator
if isinstance(dataloader_to_iter_on, CycleIterator):
it._loader_iter.state = state
else:
it.state = state
return it

# create a local token type, so it doesn't activate for anything other than an iterator state dict.
class DataLoaderDict(dict):
pass

# apply the `create_loader_iters` on the collection of `DataLoader / Iterator`.
# each `Iterator` was created from the `DataLoader`.
iterator._loader_iters = apply_to_collections(
self.loaders, self._loaders_iter_state_dict, (DataLoader, DataLoaderDict), create_loader_iters
self.loaders,
self._loaders_iter_state_dict,
(Iterable, DataLoaderDict),
create_loader_iters,
wrong_dtype=(Sequence, Mapping),
)

self._loaders_iter_state_dict = None

@property
def sampler(self) -> Union[Iterable, Sequence, Mapping]:
"""Return a collections of samplers extracting from loaders."""
@@ -457,7 +493,6 @@ def _wrap_loaders_max_size_cycle(self) -> Any:
self.loaders = apply_to_collection(
self.loaders, Iterable, CycleIterator, length=length, state=state, wrong_dtype=(Sequence, Mapping)
)

state.reset()

def __iter__(self) -> Any:
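Why `state_dict(has_completed=...)` above distinguishes a completed fetch from an in-flight one: with prefetching, the dataloader iterator typically runs one batch ahead of the training loop, so an interrupted run must restore the snapshot taken before the last fetch. A minimal sketch of that bookkeeping (the `TrackedIterator` name and its fields are illustrative, not Lightning's API):

```python
from typing import Dict, Iterator, List


class TrackedIterator:
    def __init__(self, data: List[int]) -> None:
        self._it: Iterator[int] = iter(data)
        self.previous_state: Dict[str, int] = {"num_batches_fetched": 0}
        self.state: Dict[str, int] = {"num_batches_fetched": 0}

    def __next__(self) -> int:
        batch = next(self._it)
        # the snapshot taken before this fetch becomes the rollback point
        self.previous_state = dict(self.state)
        self.state = {"num_batches_fetched": self.state["num_batches_fetched"] + 1}
        return batch

    def state_dict(self, has_completed: bool = False) -> Dict[str, int]:
        # completed epoch: the latest state is safe to save;
        # interrupted epoch: fall back to the state before the last
        # (possibly unprocessed, prefetched) batch
        return self.state if has_completed else self.previous_state


it = TrackedIterator([10, 20, 30])
next(it), next(it)  # two batches fetched, but assume only the first was fully processed
assert it.state_dict(has_completed=True) == {"num_batches_fetched": 2}
assert it.state_dict(has_completed=False) == {"num_batches_fetched": 1}
```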
38 changes: 34 additions & 4 deletions pytorch_lightning/utilities/auto_restart.py
@@ -16,8 +16,12 @@
from copy import deepcopy
from dataclasses import dataclass, field
from functools import partial, wraps
from random import getstate as python_get_rng_state
from random import setstate as python_set_rng_state
from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Tuple, Union

import numpy as np
import torch
from torch.utils.data import Dataset, get_worker_info, Sampler
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset

@@ -168,6 +172,16 @@ def update(self, generator_name: Optional[str], new_state: IteratorState) -> Non
state[latest_worker_id] = new_state
self.latest_worker_id = latest_worker_id

@property
def sampler_states(self) -> Dict[int, Any]:
"""Returns the merged sampler states for all worker processes."""
return {0: self.state[k].sampler_state[0] for k in self.state.keys()}

@property
def dataset_states(self) -> Dict[int, Any]:
"""Returns the merged dataset states for all worker processes."""
return {k: self.state[k].dataset_state[k] for k in self.state.keys()}

@classmethod
def from_state_dict(cls, state_dict) -> "MergedIteratorState":
if state_dict["represent_map_dataset"]:
@@ -188,7 +202,12 @@ def __len__(self) -> int:


class CaptureMapDataset(Dataset):
"""This class is used to capture the state from the map-based state dataset."""
"""This class is used to capture the state from the map-based state dataset.

Note:
We currently don't support restoring if we fail during the first `N = num_workers` batches, where
`num_workers` is the number of workers spawned by the dataloader.
"""

def __init__(self, dataset: Dataset) -> None:
self.dataset = dataset
@@ -202,8 +221,7 @@ def worker_id(self) -> int:
def __getitem__(self, item) -> Tuple[Any, Dict[int, Dict]]:
if self._cached_state_dict is not None:
if self.worker_id in self._cached_state_dict:
# TODO: reset random states
pass
set_rng_states(self._cached_state_dict[self.worker_id]["rng_states"])
self._cached_state_dict = None

data = self.dataset[item]
@@ -227,7 +245,19 @@ def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num
self._cached_state_dict = state_dict

def _state_dict(self) -> Dict[int, Dict[str, Any]]:
return {self.worker_id: {"rng_states": {}}}
return {self.worker_id: {"rng_states": collect_rng_states()}}


def collect_rng_states() -> Dict[str, Any]:
"""Collect the global random state of :mod:`torch`, :mod:`numpy` and Python."""
return {"torch": torch.get_rng_state(), "numpy": np.random.get_state(), "python": python_get_rng_state()}


def set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
"""Set the global random state of :mod:`torch`, :mod:`numpy` and Python in the current process."""
torch.set_rng_state(rng_state_dict.get("torch"))
np.random.set_state(rng_state_dict.get("numpy"))
python_set_rng_state(rng_state_dict.get("python"))


class CaptureIterableDataset(IterableDataset):
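For reference, the global-RNG snapshot/restore added by `collect_rng_states` and `set_rng_states` above can be exercised in isolation. The small script below re-declares equivalent helpers (so it runs without importing Lightning internals) and checks that restoring the captured state reproduces identical draws from torch, NumPy, and Python's `random`:

```python
import random
from typing import Any, Dict

import numpy as np
import torch


def collect_rng_states() -> Dict[str, Any]:
    # snapshot of the three global random states, mirroring the PR's helper
    return {"torch": torch.get_rng_state(), "numpy": np.random.get_state(), "python": random.getstate()}


def set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
    # restore all three global random states in the current process
    torch.set_rng_state(rng_state_dict["torch"])
    np.random.set_state(rng_state_dict["numpy"])
    random.setstate(rng_state_dict["python"])


states = collect_rng_states()
first = (torch.rand(1).item(), np.random.rand(), random.random())

set_rng_states(states)
second = (torch.rand(1).item(), np.random.rand(), random.random())

# restoring the global state reproduces the exact same draws
assert first == second
```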