[Feature] Non-blocking for consolidated TD #1020

Merged · 4 commits · Oct 3, 2024
tensordict/base.py (27 changes: 21 additions & 6 deletions)
```diff
@@ -10328,17 +10328,20 @@ def to(self, *args, **kwargs) -> T:
         if device is not None and dtype is None and device == self.device:
             return result
 
+        if self.is_consolidated() and dtype is None:
+            return self._to_consolidated(
+                device=device,
+                pin_memory=non_blocking_pin,
+                num_threads=num_threads,
+                non_blocking=non_blocking,
+            )
+
         if non_blocking is None:
             sub_non_blocking = True
             non_blocking = False
         else:
             sub_non_blocking = non_blocking
 
-        if self.is_consolidated() and dtype is None:
-            return self._to_consolidated(
-                device=device, pin_memory=non_blocking_pin, num_threads=num_threads
-            )
-
         if convert_to_format is not None:
 
             def to(tensor):
```
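For context: the consolidated branch now runs before `non_blocking` is normalized to a bool, so `_to_consolidated` sees the caller's original value (`True`, `False`, or `None`) and can decide whether to insert a sync itself. A minimal usage sketch, assuming tensordict's public `consolidate()`/`to()` API (tensor names and sizes are illustrative):

```python
import torch
from tensordict import TensorDict

td = TensorDict(
    {"obs": torch.randn(128, 84, 84), "reward": torch.randn(128, 1)},
    batch_size=[128],
)

# consolidate() packs all leaves into one contiguous storage, so the
# transfer below is a single copy rather than one copy per leaf.
td_c = td.consolidate()

if torch.cuda.is_available():
    # With non_blocking=True the implicit sync added by this PR is skipped;
    # the caller must synchronize before reading the result.
    td_cuda = td_c.to("cuda", non_blocking=True)
    torch.cuda.synchronize()
```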
```diff
@@ -10388,7 +10391,7 @@ def to_pinmem(tensor, _to=to):
             self._sync_all()
         return result
 
-    def _to_consolidated(self, *, device, pin_memory, num_threads):
+    def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking):
         if num_threads is None:
             # unspecified num_threads should mean 0
             num_threads = 0
```
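The rest of `_to_consolidated` (unchanged, hence collapsed in the diff) casts the single consolidated storage instead of copying each leaf, which is where `storage_cast` below comes from. A rough sketch of that idea, under the assumption that the consolidated buffer is exposed as one flat tensor (the helper name is hypothetical, not tensordict's actual implementation):

```python
import torch

def copy_consolidated_buffer(
    flat: torch.Tensor,
    device: torch.device,
    *,
    pin_memory: bool = False,
    non_blocking: bool = False,
) -> torch.Tensor:
    # Hypothetical helper: move the single flat buffer that backs every
    # leaf in one transfer instead of issuing one copy per tensor.
    if pin_memory and flat.device.type == "cpu":
        flat = flat.pin_memory()  # pinned host memory enables async H2D copies
    return flat.to(device, non_blocking=non_blocking)
```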
```diff
@@ -10414,6 +10417,18 @@ def set_(x):
         result._consolidated = {"storage": storage_cast}
         if "metadata" in self._consolidated:
             result._consolidated["metadata"] = deepcopy(self._consolidated["metadata"])
+        if non_blocking in (False, None):
+            if device.type == "cuda" and non_blocking is False:
+                # sending to CUDA forces a sync
+                cuda_device = device
+            elif storage.device.type == "cuda":
+                # sending from CUDA: sync unless the caller explicitly opted out
+                cuda_device = storage.device.type
+            else:
+                cuda_device = None
+            if cuda_device is not None:
+                torch.cuda.current_stream(cuda_device).synchronize()
+
         return result
 
     def _sync_all(self):
```
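Restating the new sync rule outside the diff: a transfer is followed by a stream synchronization exactly when the caller did not explicitly pass `non_blocking=True` and a CUDA device sits on the side that forces the sync. A small sketch of that decision (the `needs_sync` helper is made up for illustration; its logic mirrors the branch above):

```python
import torch

def needs_sync(src: torch.device, dst: torch.device, non_blocking) -> bool:
    # Mirrors the logic in _to_consolidated: a sync is inserted whenever
    # the caller did not explicitly request an async transfer.
    if non_blocking is True:
        return False  # caller takes responsibility for synchronizing
    if dst.type == "cuda" and non_blocking is False:
        return True   # explicit non_blocking=False to CUDA: force a sync
    if src.type == "cuda":
        return True   # copying off a CUDA device: sync unless opted out
    return False

# Examples (assuming a CUDA build):
# needs_sync(torch.device("cpu"),  torch.device("cuda"), None)  -> False
# needs_sync(torch.device("cpu"),  torch.device("cuda"), False) -> True
# needs_sync(torch.device("cuda"), torch.device("cpu"),  None)  -> True
# needs_sync(torch.device("cuda"), torch.device("cpu"),  True)  -> False
```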