From feb4625b81dbbf4e039fe578850cd53b4e81e34d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 2 Oct 2024 11:03:20 +0100 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- tensordict/base.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/tensordict/base.py b/tensordict/base.py index c34589024..5310a4672 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -10236,17 +10236,20 @@ def to(self, *args, **kwargs) -> T: if device is not None and dtype is None and device == self.device: return result + if self.is_consolidated() and dtype is None: + return self._to_consolidated( + device=device, + pin_memory=non_blocking_pin, + num_threads=num_threads, + non_blocking=non_blocking, + ) + if non_blocking is None: sub_non_blocking = True non_blocking = False else: sub_non_blocking = non_blocking - if self.is_consolidated() and dtype is None: - return self._to_consolidated( - device=device, pin_memory=non_blocking_pin, num_threads=num_threads - ) - if convert_to_format is not None: def to(tensor): @@ -10296,7 +10299,7 @@ def to_pinmem(tensor, _to=to): self._sync_all() return result - def _to_consolidated(self, *, device, pin_memory, num_threads): + def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking): if num_threads is None: # unspecified num_threads should mean 0 num_threads = 0 @@ -10322,6 +10325,16 @@ def set_(x): result._consolidated = {"storage": storage_cast} if "metadata" in self._consolidated: result._consolidated["metadata"] = deepcopy(self._consolidated["metadata"]) + if not non_blocking: + if device.type == "cuda": + cuda_device = device + elif storage.device.type == "cuda": + cuda_device = device + else: + cuda_device = None + if cuda_device is not None: + torch.cuda.current_stream(cuda_device).synchronize() + return result def _sync_all(self): From 9798b2d307444d961703343a0f6173e7d5f12055 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 3 Oct 2024 19:59:48 +0100 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- tensordict/base.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tensordict/base.py b/tensordict/base.py index ea2aaaa82..529146597 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -10417,11 +10417,13 @@ def set_(x): result._consolidated = {"storage": storage_cast} if "metadata" in self._consolidated: result._consolidated["metadata"] = deepcopy(self._consolidated["metadata"]) - if not non_blocking: - if device.type == "cuda": + if non_blocking in (False, None): + if device.type == "cuda" and non_blocking is False: + # sending to CUDA force sync cuda_device = device elif storage.device.type == "cuda": - cuda_device = device + # sending from cuda: need sync unless intentionally not asked for + cuda_device = storage.device.type else: cuda_device = None if cuda_device is not None: