Skip to content

Commit

Permalink
[Feature] Send info dict to the storage device in RBs
Browse files Browse the repository at this point in the history
ghstack-source-id: 4ed60d649b17f96b49f90d234e679937c60a3c32
Pull Request resolved: #2527
  • Loading branch information
vmoens committed Oct 29, 2024
1 parent 3e4b292 commit 304e802
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions torchrl/data/replay_buffers/replay_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,7 @@ def sample(self, batch_size: int | None = None, return_info: bool = False) -> An
"for a proper usage of the batch-size arguments."
)
if not self._prefetch:
ret = self._sample(batch_size)
result = self._sample(batch_size)
else:
with self._futures_lock:
while (
Expand All @@ -722,11 +722,15 @@ def sample(self, batch_size: int | None = None, return_info: bool = False) -> An
) or not len(self._prefetch_queue):
fut = self._prefetch_executor.submit(self._sample, batch_size)
self._prefetch_queue.append(fut)
ret = self._prefetch_queue.popleft().result()
result = self._prefetch_queue.popleft().result()

if return_info:
return ret
return ret[0]
out, info = result
if getattr(self.storage, "device", None) is not None:
device = self.storage.device
info = tree_map(lambda x: x.to(device) if hasattr(x, "to") else x, info)
return out, info
return result[0]

def mark_update(self, index: Union[int, torch.Tensor]) -> None:
self._sampler.mark_update(index, storage=self._storage)
Expand Down

0 comments on commit 304e802

Please sign in to comment.