Skip to content

Commit

Permalink
[BugFix] Fix td device sync when error is raised
Browse files (browse the repository at this point in the history)
ghstack-source-id: d0e810c71ca1c9945561ca5a9e71cb71445095e4
Pull Request resolved: #988
  • Loading branch information
vmoens committed Sep 12, 2024
1 parent d4f8eee commit 2fa8d7a
Showing 1 changed file with 26 additions and 21 deletions.
47 changes: 26 additions & 21 deletions tensordict/_td.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -272,29 +272,34 @@ def __init__(
call_sync = non_blocking is None
if call_sync:
_device_recorder.mark()
self._device = device
try:
self._device = device

if source is None:
source = {}
if not isinstance(source, (TensorDictBase, dict)):
raise ValueError(
"A TensorDict source is expected to be a TensorDictBase "
f"sub-type or a dictionary, found type(source)={type(source)}."
)
self._batch_size = self._parse_batch_size(source, batch_size)
# TODO: this breaks when stacking tensorclasses with dynamo
if not is_dynamo_compiling():
self.names = names

for key, value in source.items():
self.set(key, value, non_blocking=sub_non_blocking)
if call_sync:
if _device_recorder.has_transfer():
self._sync_all()
_device_recorder.unmark()
if source is None:
source = {}
if not isinstance(source, (TensorDictBase, dict)):
raise ValueError(
"A TensorDict source is expected to be a TensorDictBase "
f"sub-type or a dictionary, found type(source)={type(source)}."
)
self._batch_size = self._parse_batch_size(source, batch_size)
# TODO: this breaks when stacking tensorclasses with dynamo
if not is_dynamo_compiling():
self.names = names

if lock:
self.lock_()
for key, value in source.items():
self.set(key, value, non_blocking=sub_non_blocking)
if call_sync:
if _device_recorder.has_transfer():
self._sync_all()
_device_recorder.unmark()
call_sync = False

if lock:
self.lock_()
finally:
if call_sync:
_device_recorder.unmark()

@classmethod
def _new_unsafe(
Expand Down

0 comments on commit 2fa8d7a

Please sign in to comment.