|
73 | 73 | from torchrl.data.datasets.roboset import RobosetExperienceReplay |
74 | 74 | from torchrl.data.datasets.vd4rl import VD4RLExperienceReplay |
75 | 75 | from torchrl.data.replay_buffers import SamplerWithoutReplacement |
| 76 | +from torchrl.data.replay_buffers.samplers import SliceSampler |
| 77 | +from torchrl.data.replay_buffers.storages import LazyTensorStorage |
76 | 78 | from torchrl.data.utils import CloudpickleWrapper |
77 | 79 | from torchrl.envs import ( |
78 | 80 | CatTensors, |
|
82 | 84 | EnvCreator, |
83 | 85 | RemoveEmptySpecs, |
84 | 86 | RenameTransform, |
| 87 | + StepCounter, |
85 | 88 | ) |
86 | 89 | from torchrl.envs.batched_envs import SerialEnv |
87 | 90 | from torchrl.envs.libs.brax import _has_brax, BraxEnv, BraxWrapper |
@@ -2790,7 +2793,7 @@ class TestVmas: |
2790 | 2793 | @pytest.mark.parametrize("scenario_name", VmasWrapper.available_envs) |
2791 | 2794 | @pytest.mark.parametrize("continuous_actions", [True, False]) |
2792 | 2795 | def test_all_vmas_scenarios(self, scenario_name, continuous_actions): |
2793 | | - |
| 2796 | + |
2794 | 2797 | env = VmasEnv( |
2795 | 2798 | scenario=scenario_name, |
2796 | 2799 | continuous_actions=continuous_actions, |
@@ -3455,6 +3458,8 @@ def test_d4rl_dummy(self, task): |
3455 | 3458 | @pytest.mark.parametrize("split_trajs", [True, False]) |
3456 | 3459 | @pytest.mark.parametrize("from_env", [True, False]) |
3457 | 3460 | def test_dataset_build(self, task, split_trajs, from_env): |
| 3461 | + import d4rl # noqa: F401 |
| 3462 | + |
3458 | 3463 | t0 = time.time() |
3459 | 3464 | data = D4RLExperienceReplay( |
3460 | 3465 | task, split_trajs=split_trajs, from_env=from_env, batch_size=2 |
@@ -5144,6 +5149,17 @@ def test_isaaclab(self, env): |
5144 | 5149 | env.check_env_specs(break_when_any_done="both") |
5145 | 5150 | torchrl_logger.info("Check succeeded!") |
5146 | 5151 |
|
| 5152 | + def test_isaaclab_rb(self, env): |
| 5153 | + env = env.append_transform(StepCounter()) |
| 5154 | + rb = ReplayBuffer( |
| 5155 | + storage=LazyTensorStorage(50, ndim=2), sampler=SliceSampler(num_slices=5) |
| 5156 | + ) |
| 5157 | + rb.extend(env.rollout(20)) |
| 5158 | + # check that rb["step_count"].flatten() is made of sequences of 4 consecutive numbers |
| 5159 | + flat_ranges = rb["step_count"].flatten() % 4 |
| 5160 | + arange = torch.arange(flat_ranges.numel(), device=flat_ranges.device) % 4 |
| 5161 | + assert (flat_ranges == arange).all() |
| 5162 | + |
5147 | 5163 | def test_isaac_collector(self, env): |
5148 | 5164 | col = SyncDataCollector( |
5149 | 5165 | env, env.rand_action, frames_per_batch=1000, total_frames=100_000_000 |
|
0 commit comments