Skip to content

Commit

Permalink
Resolve comments
Browse files Browse the repository at this point in the history
  • Loading branch information
ejguan committed Feb 16, 2023
1 parent 8042a1a commit 956748c
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 27 deletions.
1 change: 1 addition & 0 deletions torchdata/datapipes/iter/__init__.pyi.in
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ ${init_base}
from .util.decompressor import CompressionType
from torchdata._constants import default_timeout_in_s
from torchdata.datapipes.map import MapDataPipe
from torchdata.datapipes.iter.util.prefetcher import pin_memory_fn
from torch.utils.data import DataChunk, IterableDataset, default_collate
from torch.utils.data.datapipes._typing import _DataPipeMeta
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
Expand Down
30 changes: 3 additions & 27 deletions torchdata/datapipes/iter/util/prefetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import collections
import threading
import time

Expand All @@ -15,6 +14,7 @@

from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe
from torchdata.utils import pin_memory_fn

PRODUCER_SLEEP_INTERVAL = 0.0001 # Interval between buffer fulfillment checks
CONSUMER_SLEEP_INTERVAL = 0.0001 # Interval between checking items availability in buffer
Expand Down Expand Up @@ -143,31 +143,6 @@ def resume(self):
self.prefetch_data.run_prefetcher = True


def pin_memory_fn(data, device=None):
    """Recursively pin the memory of tensors nested inside ``data``.

    Any object exposing ``pin_memory`` (e.g. ``torch.Tensor``) is pinned;
    mappings and sequences are traversed and rebuilt with the original
    container type where possible. Anything else is returned unchanged.

    Args:
        data: Arbitrary, possibly nested data (tensors, mappings, sequences,
            strings/bytes, or opaque objects).
        device: Optional device forwarded to ``pin_memory``.

    Returns:
        A structure mirroring ``data`` with every pinnable leaf pinned.
    """
    if hasattr(data, "pin_memory"):
        # Covers torch.Tensor (which always has `pin_memory`) and any custom
        # type implementing the same protocol, so no separate Tensor check
        # is needed.
        return data.pin_memory(device)
    elif isinstance(data, (str, bytes)):
        # str/bytes are Sequences; short-circuit here so we don't recurse
        # into individual characters/ints. Fix: original only handled `str`,
        # so `bytes` input came back as a list of ints.
        return data
    elif isinstance(data, collections.abc.Mapping):
        pinned_data = {k: pin_memory_fn(sample, device) for k, sample in data.items()}
        try:
            return type(data)(**pinned_data)
        except TypeError:
            # The mapping type may not support `__init__(**kwargs)`.
            return pinned_data
    elif isinstance(data, collections.abc.Sequence):
        pinned_data = [pin_memory_fn(sample, device) for sample in data]
        if isinstance(data, tuple) and hasattr(data, "_fields"):
            # namedtuple: must be rebuilt with positional args.
            return type(data)(*pinned_data)
        try:
            # Fix: original called `type(data)(*pinned_data)` WITHOUT
            # `return`, falling through to return None whenever the call
            # succeeded. Plain sequences are rebuilt from an iterable,
            # matching torch.utils.data._utils.pin_memory.
            return type(data)(pinned_data)
        except TypeError:
            # The sequence type may not support `__init__(iterable)` (e.g., `range`).
            return pinned_data
    else:
        return data


@functional_datapipe("pin_memory")
class PinMemoryIterDataPipe(PrefetcherIterDataPipe):
r"""
Expand All @@ -189,7 +164,8 @@ class PinMemoryIterDataPipe(PrefetcherIterDataPipe):
def __init__(self, source_datapipe, device=None, pin_memory_fn=pin_memory_fn):
if not torch.cuda.is_available():
raise RuntimeError("``pin_memory`` can only be used when CUDA is available.")
super().__init__(source_datapipe, 1)
# TODO: Add support for dynamic buffer based on the available size of pinned memory
super().__init__(source_datapipe, buffer_size=2)
if device is None:
device = torch.cuda.current_device()
self.device = device
Expand Down
11 changes: 11 additions & 0 deletions torchdata/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchdata.utils.pin_memory import pin_memory_fn

__all__ = ["pin_memory_fn"]

assert __all__ == sorted(__all__)
34 changes: 34 additions & 0 deletions torchdata/utils/pin_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import collections

import torch


def pin_memory_fn(data, device=None):
    """Recursively pin the memory of tensors nested inside ``data``.

    Any object exposing ``pin_memory`` (e.g. ``torch.Tensor``) is pinned;
    mappings and sequences are traversed and rebuilt with the original
    container type where possible. Anything else is returned unchanged.

    Args:
        data: Arbitrary, possibly nested data (tensors, mappings, sequences,
            strings/bytes, or opaque objects).
        device: Optional device forwarded to ``pin_memory``.

    Returns:
        A structure mirroring ``data`` with every pinnable leaf pinned.
    """
    if hasattr(data, "pin_memory"):
        # Covers torch.Tensor (which always has `pin_memory`) and any custom
        # type implementing the same protocol, so no separate Tensor check
        # is needed.
        return data.pin_memory(device)
    elif isinstance(data, (str, bytes)):
        # str/bytes are Sequences; short-circuit here so we don't recurse
        # into individual characters/ints.
        return data
    elif isinstance(data, collections.abc.Mapping):
        pinned_data = {k: pin_memory_fn(sample, device) for k, sample in data.items()}
        try:
            return type(data)(**pinned_data)
        except TypeError:
            # The mapping type may not support `__init__(**kwargs)`.
            return pinned_data
    elif isinstance(data, collections.abc.Sequence):
        pinned_data = [pin_memory_fn(sample, device) for sample in data]
        if isinstance(data, tuple) and hasattr(data, "_fields"):
            # namedtuple: must be rebuilt with positional args.
            return type(data)(*pinned_data)
        try:
            # Fix: original called `type(data)(*pinned_data)` WITHOUT
            # `return`, falling through to return None whenever the call
            # succeeded (e.g. pin_memory_fn([[1, 2]]) returned None).
            # Plain sequences are rebuilt from an iterable, matching
            # torch.utils.data._utils.pin_memory.
            return type(data)(pinned_data)
        except TypeError:
            # The sequence type may not support `__init__(iterable)` (e.g., `range`).
            return pinned_data
    else:
        return data

0 comments on commit 956748c

Please sign in to comment.