Skip wrapping if serialization wrapper attached & deepcopy DataPipe in load_state_dict (#833)

Summary:
Per title
- We don't need to attach a serialization wrapper if the last DataPipe is already a `_DataPipeSerializationWrapper`.
- When we `load_state_dict`, we still need a copy of `self.datapipe`, as is already done in the `__init__` function. The two places should be consistent: either deepcopy into `self._datapipe_before_reading_service_adapt` in both, or skip the deepcopy in both.
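
The first point can be sketched as a guard that makes wrapping idempotent. This is a minimal illustration only: `SerializationWrapper` and `wrap_once` are hypothetical stand-ins, not the torch classes or the `DataLoader2` method, and a plain list stands in for a DataPipe.

```python
from typing import Any


class SerializationWrapper:
    """Hypothetical stand-in for a serialization wrapper class."""

    def __init__(self, datapipe: Any) -> None:
        self.datapipe = datapipe


def wrap_once(datapipe: Any) -> Any:
    """Wrap `datapipe` unless it is already wrapped (hypothetical helper)."""
    if isinstance(datapipe, SerializationWrapper):
        # Already wrapped: return it unchanged instead of nesting wrappers.
        return datapipe
    return SerializationWrapper(datapipe)


source = [1, 2, 3]  # any object stands in for a DataPipe here
once = wrap_once(source)
twice = wrap_once(once)

# The second call is a no-op, so the wrapper is never nested.
assert twice is once
assert once.datapipe is source
```

Without the guard, calling the wrapping step on an already wrapped pipeline would produce a wrapper around a wrapper.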

Pull Request resolved: #833

Reviewed By: NivekT

Differential Revision: D40399573

Pulled By: ejguan

fbshipit-source-id: 16fc80bd005a4b8671d48780c7aea8f164bccba8
ejguan authored and facebook-github-bot committed Oct 17, 2022
1 parent 63e5f2f commit 2fd9f98
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions torchdata/dataloader2/dataloader2.py
@@ -9,7 +9,11 @@
 from dataclasses import dataclass
 from typing import Any, Dict, Generic, Iterable, Iterator, Optional, TypeVar, Union

-from torch.utils.data.datapipes.datapipe import _IterDataPipeSerializationWrapper, _MapDataPipeSerializationWrapper
+from torch.utils.data.datapipes.datapipe import (
+    _DataPipeSerializationWrapper,
+    _IterDataPipeSerializationWrapper,
+    _MapDataPipeSerializationWrapper,
+)

 from torch.utils.data.graph import DataPipe
 from torchdata.dataloader2.adapter import Adapter
@@ -178,10 +182,11 @@ def _wrap_and_copy_dp(datapipe: DataPipe):
         """
         wrapped_dp: DataPipe = datapipe
-        if isinstance(datapipe, IterDataPipe):
-            wrapped_dp = _IterDataPipeSerializationWrapper(datapipe)
-        elif isinstance(datapipe, MapDataPipe):
-            wrapped_dp = _MapDataPipeSerializationWrapper(datapipe)
+        if not isinstance(datapipe, _DataPipeSerializationWrapper):
+            if isinstance(datapipe, IterDataPipe):
+                wrapped_dp = _IterDataPipeSerializationWrapper(datapipe)
+            elif isinstance(datapipe, MapDataPipe):
+                wrapped_dp = _MapDataPipeSerializationWrapper(datapipe)
         return DataLoader2._copy(wrapped_dp)

     def shutdown(self) -> None:
@@ -269,4 +274,4 @@ def load_state_dict(self, state: Dict[str, Any]) -> None:
         if self.datapipe_adapter_fns is not None:
             for adapter_fn in self.datapipe_adapter_fns:
                 self.datapipe = adapter_fn(self.datapipe)
-        self._datapipe_before_reading_service_adapt = self.datapipe
+        self._datapipe_before_reading_service_adapt = self._copy(self.datapipe)
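
The last hunk swaps a plain assignment for a copy. A rough sketch of why that can matter, assuming a later step modifies the pipeline in place (the commit does not state this explicitly): `Loader`, `snapshot`, and `adapt` below are hypothetical names, and a list stands in for the DataPipe graph.

```python
import copy


class Loader:
    """Hypothetical stand-in holding a pipeline and a snapshot of it."""

    def __init__(self, pipeline: list, snapshot_by_copy: bool) -> None:
        self.pipeline = pipeline
        # Snapshot taken before a later stage adapts the pipeline.
        self.snapshot = copy.deepcopy(pipeline) if snapshot_by_copy else pipeline

    def adapt(self) -> None:
        # Stands in for an in-place modification of the pipeline.
        self.pipeline.append("adapted")


aliased = Loader(["source"], snapshot_by_copy=False)
aliased.adapt()

copied = Loader(["source"], snapshot_by_copy=True)
copied.adapt()

# With plain assignment the snapshot aliases the pipeline and changes with it.
assert aliased.snapshot == ["source", "adapted"]
# With a deep copy the snapshot keeps the pre-adaptation state.
assert copied.snapshot == ["source"]
```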
