Skip to content

Commit 963fdd4

Browse files
committed
[BugFix] Fix tests
ghstack-source-id: 5704ab4 Pull-Request: #3218
1 parent 92c20cd commit 963fdd4

File tree

2 files changed

+18
-13
lines changed

2 files changed

+18
-13
lines changed

test/test_weightsync.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99
import pytest
1010
import torch
1111
import torch.nn as nn
12+
from mocking_classes import ContinuousActionVecMockEnv
1213
from tensordict import TensorDict
1314
from tensordict.nn import TensorDictModule
1415
from torch import multiprocessing as mp
1516
from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector
16-
from torchrl.envs import GymEnv
1717
from torchrl.weight_update.weight_sync_schemes import (
1818
_resolve_model,
1919
MPTransport,
@@ -274,7 +274,7 @@ def test_no_weight_sync_scheme(self):
274274
class TestCollectorIntegration:
275275
@pytest.fixture
276276
def simple_env(self):
277-
return GymEnv("CartPole-v1")
277+
return ContinuousActionVecMockEnv()
278278

279279
@pytest.fixture
280280
def simple_policy(self, simple_env):
@@ -291,7 +291,7 @@ def test_syncdatacollector_multiprocess_scheme(self, simple_policy):
291291
scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
292292

293293
collector = SyncDataCollector(
294-
create_env_fn=lambda: GymEnv("CartPole-v1"),
294+
create_env_fn=ContinuousActionVecMockEnv,
295295
policy=simple_policy,
296296
frames_per_batch=64,
297297
total_frames=128,
@@ -316,8 +316,8 @@ def test_multisyncdatacollector_multiprocess_scheme(self, simple_policy):
316316

317317
collector = MultiSyncDataCollector(
318318
create_env_fn=[
319-
lambda: GymEnv("CartPole-v1"),
320-
lambda: GymEnv("CartPole-v1"),
319+
ContinuousActionVecMockEnv,
320+
ContinuousActionVecMockEnv,
321321
],
322322
policy=simple_policy,
323323
frames_per_batch=64,
@@ -343,8 +343,8 @@ def test_multisyncdatacollector_shared_mem_scheme(self, simple_policy):
343343

344344
collector = MultiSyncDataCollector(
345345
create_env_fn=[
346-
lambda: GymEnv("CartPole-v1"),
347-
lambda: GymEnv("CartPole-v1"),
346+
ContinuousActionVecMockEnv,
347+
ContinuousActionVecMockEnv,
348348
],
349349
policy=simple_policy,
350350
frames_per_batch=64,
@@ -369,7 +369,7 @@ def test_collector_no_weight_sync(self, simple_policy):
369369
scheme = NoWeightSyncScheme()
370370

371371
collector = SyncDataCollector(
372-
create_env_fn=lambda: GymEnv("CartPole-v1"),
372+
create_env_fn=ContinuousActionVecMockEnv,
373373
policy=simple_policy,
374374
frames_per_batch=64,
375375
total_frames=128,
@@ -385,7 +385,7 @@ def test_collector_no_weight_sync(self, simple_policy):
385385

386386
class TestMultiModelUpdates:
387387
def test_multi_model_state_dict_updates(self):
388-
env = GymEnv("CartPole-v1")
388+
env = ContinuousActionVecMockEnv()
389389

390390
policy = TensorDictModule(
391391
nn.Linear(
@@ -407,7 +407,7 @@ def test_multi_model_state_dict_updates(self):
407407
}
408408

409409
collector = SyncDataCollector(
410-
create_env_fn=lambda: GymEnv("CartPole-v1"),
410+
create_env_fn=ContinuousActionVecMockEnv,
411411
policy=policy,
412412
frames_per_batch=64,
413413
total_frames=128,
@@ -438,7 +438,7 @@ def test_multi_model_state_dict_updates(self):
438438
env.close()
439439

440440
def test_multi_model_tensordict_updates(self):
441-
env = GymEnv("CartPole-v1")
441+
env = ContinuousActionVecMockEnv()
442442

443443
policy = TensorDictModule(
444444
nn.Linear(
@@ -460,7 +460,7 @@ def test_multi_model_tensordict_updates(self):
460460
}
461461

462462
collector = SyncDataCollector(
463-
create_env_fn=lambda: GymEnv("CartPole-v1"),
463+
create_env_fn=ContinuousActionVecMockEnv,
464464
policy=policy,
465465
frames_per_batch=64,
466466
total_frames=128,

torchrl/envs/libs/gym.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1255,7 +1255,12 @@ def _build_gym_env(self, env, pixels_only): # noqa: F811
12551255

12561256
@property
12571257
def lib(self) -> ModuleType:
1258-
return gym_backend()
1258+
gym = gym_backend()
1259+
if gym is None:
1260+
raise RuntimeError(
1261+
"Gym backend is not available. Please install gym or gymnasium."
1262+
)
1263+
return gym
12591264

12601265
def _set_seed(self, seed: int | None) -> None: # noqa: F811
12611266
if self._seed_calls_reset is None:

0 commit comments

Comments
 (0)