diff --git a/test/test_collector.py b/test/test_collector.py
index 9b0117e7486..c44bfdd6cdd 100644
--- a/test/test_collector.py
+++ b/test/test_collector.py
@@ -2751,6 +2751,53 @@ def test_async(self, use_buffers):
         del collector
 
 
+class TestCollectorRB:
+    def test_collector_rb_sync(self):
+        env = SerialEnv(8, lambda cp=CARTPOLE_VERSIONED(): GymEnv(cp))
+        env.set_seed(0)
+        rb = ReplayBuffer(storage=LazyTensorStorage(256, ndim=2), batch_size=5)
+        collector = SyncDataCollector(
+            env,
+            RandomPolicy(env.action_spec),
+            replay_buffer=rb,
+            total_frames=256,
+            frames_per_batch=16,
+        )
+        torch.manual_seed(0)
+
+        for c in collector:
+            assert c is None
+            rb.sample()
+        rbdata0 = rb[:].clone()
+        collector.shutdown()
+        if not env.is_closed:
+            env.close()
+        del collector, env
+
+        env = SerialEnv(8, lambda cp=CARTPOLE_VERSIONED(): GymEnv(cp))
+        env.set_seed(0)
+        rb = ReplayBuffer(storage=LazyTensorStorage(256, ndim=2), batch_size=5)
+        collector = SyncDataCollector(
+            env, RandomPolicy(env.action_spec), total_frames=256, frames_per_batch=16
+        )
+        torch.manual_seed(0)
+
+        for i, c in enumerate(collector):
+            rb.extend(c)
+            torch.testing.assert_close(
+                rbdata0[:, : (i + 1) * 2]["observation"], rb[:]["observation"]
+            )
+            assert c is not None
+            rb.sample()
+
+        rbdata1 = rb[:].clone()
+        collector.shutdown()
+        if not env.is_closed:
+            env.close()
+        del collector, env
+        assert assert_allclose_td(rbdata0, rbdata1)
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index be24a06e39c..6b949a12015 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -50,6 +50,7 @@
     VERBOSE,
 )
 from torchrl.collectors.utils import split_trajectories
+from torchrl.data import ReplayBuffer
 from torchrl.data.tensor_specs import TensorSpec
 from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
 from torchrl.envs.common import _do_nothing, EnvBase
@@ -357,6 +358,8 @@ class SyncDataCollector(DataCollectorBase):
         use_buffers (bool, optional): if ``True``, a buffer will be used to stack the
             data. This isn't compatible with environments with dynamic specs. Defaults to ``True``
             for envs without dynamic specs, ``False`` for others.
+        replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts
+            but populate the buffer instead. Defaults to ``None``.
 
     Examples:
         >>> from torchrl.envs.libs.gym import GymEnv
@@ -446,6 +449,7 @@ def __init__(
         interruptor=None,
         set_truncated: bool = False,
         use_buffers: bool | None = None,
+        replay_buffer: ReplayBuffer | None = None,
     ):
         from torchrl.envs.batched_envs import BatchedEnvBase
 
@@ -538,9 +542,17 @@ def __init__(
         self.env: EnvBase = env
         del env
 
+        self.replay_buffer = replay_buffer
+        if self.replay_buffer is not None:
+            if postproc is not None:
+                raise TypeError("postproc must be None when a replay buffer is passed.")
+            if use_buffers:
+                raise TypeError("replay_buffer is exclusive with use_buffers.")
         if use_buffers is None:
-            use_buffers = not self.env._has_dynamic_specs
+            use_buffers = not self.env._has_dynamic_specs and self.replay_buffer is None
         self._use_buffers = use_buffers
+        self.replay_buffer = replay_buffer
+        self.closed = False
 
         if not reset_when_done:
             raise ValueError("reset_when_done is deprectated.")
@@ -871,7 +883,15 @@ def set_seed(self, seed: int, static_seed: bool = False) -> int:
             >>> out_seed = collector.set_seed(1)  # out_seed = 6
 
         """
-        return self.env.set_seed(seed, static_seed=static_seed)
+        out = self.env.set_seed(seed, static_seed=static_seed)
+        return out
+
+    def _increment_frames(self, numel):
+        self._frames += numel
+        completed = self._frames >= self.total_frames
+        if completed:
+            self.env.close()
+        return completed
 
     def iterator(self) -> Iterator[TensorDictBase]:
         """Iterates through the DataCollector.
@@ -917,14 +937,15 @@ def cuda_check(tensor: torch.Tensor):
                 for stream in streams:
                     stack.enter_context(torch.cuda.stream(stream))
 
-            total_frames = self.total_frames
-
             while self._frames < self.total_frames:
                 self._iter += 1
                 tensordict_out = self.rollout()
-                self._frames += tensordict_out.numel()
-                if self._frames >= total_frames:
-                    self.env.close()
+                if tensordict_out is None:
+                    # if a replay buffer is passed, there is no tensordict_out
+                    # frames are updated within the rollout function
+                    yield
+                    continue
+                self._increment_frames(tensordict_out.numel())
 
                 if self.split_trajs:
                     tensordict_out = split_trajectories(
@@ -1053,13 +1074,18 @@ def rollout(self) -> TensorDictBase:
                     next_data.clear_device_()
                 self._shuttle.set("next", next_data)
 
-            if self.storing_device is not None:
-                tensordicts.append(
-                    self._shuttle.to(self.storing_device, non_blocking=True)
-                )
-                self._sync_storage()
+            if self.replay_buffer is not None:
+                self.replay_buffer.add(self._shuttle)
+                if self._increment_frames(self._shuttle.numel()):
+                    return
             else:
-                tensordicts.append(self._shuttle)
+                if self.storing_device is not None:
+                    tensordicts.append(
+                        self._shuttle.to(self.storing_device, non_blocking=True)
+                    )
+                    self._sync_storage()
+                else:
+                    tensordicts.append(self._shuttle)
 
             # carry over collector data without messing up devices
             collector_data = self._shuttle.get("collector").copy()
@@ -1074,6 +1100,8 @@ def rollout(self) -> TensorDictBase:
                     self.interruptor is not None
                     and self.interruptor.collection_stopped()
                 ):
+                    if self.replay_buffer is not None:
+                        return
                     result = self._final_rollout
                     if self._use_buffers:
                         try:
@@ -1109,6 +1137,8 @@ def rollout(self) -> TensorDictBase:
                     self._final_rollout.ndim - 1,
                     out=self._final_rollout,
                 )
+            elif self.replay_buffer is not None:
+                return
             else:
                 result = TensorDict.maybe_dense_stack(tensordicts, dim=-1)
                 result.refine_names(..., "time")
diff --git a/torchrl/envs/custom/pendulum.py b/torchrl/envs/custom/pendulum.py
index f785d1cedd9..e2007227127 100644
--- a/torchrl/envs/custom/pendulum.py
+++ b/torchrl/envs/custom/pendulum.py
@@ -216,6 +216,7 @@ class PendulumEnv(EnvBase):
         "render_fps": 30,
     }
     batch_locked = False
+    rng = None
 
     def __init__(self, td_params=None, seed=None, device=None):
         if td_params is None:
@@ -224,7 +225,7 @@ def __init__(self, td_params=None, seed=None, device=None):
         super().__init__(device=device)
         self._make_spec(td_params)
         if seed is None:
-            seed = torch.empty((), dtype=torch.int64).random_().item()
+            seed = torch.empty((), dtype=torch.int64).random_(generator=self.rng).item()
         self.set_seed(seed)
 
     @classmethod
@@ -354,7 +355,8 @@ def make_composite_from_td(td):
         return composite
 
     def _set_seed(self, seed: int):
-        rng = torch.manual_seed(seed)
+        rng = torch.Generator()
+        rng.manual_seed(seed)
        self.rng = rng
 
     @staticmethod