Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Sep 12, 2024
1 parent 073483f commit d9d9225
Showing 1 changed file with 26 additions and 21 deletions.
47 changes: 26 additions & 21 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,29 +272,34 @@ def __init__(
call_sync = non_blocking is None
if call_sync:
_device_recorder.mark()
self._device = device
try:
self._device = device

if source is None:
source = {}
if not isinstance(source, (TensorDictBase, dict)):
raise ValueError(
"A TensorDict source is expected to be a TensorDictBase "
f"sub-type or a dictionary, found type(source)={type(source)}."
)
self._batch_size = self._parse_batch_size(source, batch_size)
# TODO: this breaks when stacking tensorclasses with dynamo
if not is_dynamo_compiling():
self.names = names

for key, value in source.items():
self.set(key, value, non_blocking=sub_non_blocking)
if call_sync:
if _device_recorder.has_transfer():
self._sync_all()
_device_recorder.unmark()
if source is None:
source = {}
if not isinstance(source, (TensorDictBase, dict)):
raise ValueError(
"A TensorDict source is expected to be a TensorDictBase "
f"sub-type or a dictionary, found type(source)={type(source)}."
)
self._batch_size = self._parse_batch_size(source, batch_size)
# TODO: this breaks when stacking tensorclasses with dynamo
if not is_dynamo_compiling():
self.names = names

if lock:
self.lock_()
for key, value in source.items():
self.set(key, value, non_blocking=sub_non_blocking)
if call_sync:
if _device_recorder.has_transfer():
self._sync_all()
_device_recorder.unmark()
call_sync = False

if lock:
self.lock_()
finally:
if call_sync:
_device_recorder.unmark()

@classmethod
def _new_unsafe(
Expand Down

0 comments on commit d9d9225

Please sign in to comment.