99import pytest
1010import torch
1111import torch .nn as nn
12+ from mocking_classes import ContinuousActionVecMockEnv
1213from tensordict import TensorDict
1314from tensordict .nn import TensorDictModule
1415from torch import multiprocessing as mp
1516from torchrl .collectors import MultiSyncDataCollector , SyncDataCollector
16- from torchrl .envs import GymEnv
1717from torchrl .weight_update .weight_sync_schemes import (
1818 _resolve_model ,
1919 MPTransport ,
@@ -274,7 +274,7 @@ def test_no_weight_sync_scheme(self):
274274class TestCollectorIntegration :
275275 @pytest .fixture
276276 def simple_env (self ):
277- return GymEnv ( "CartPole-v1" )
277+ return ContinuousActionVecMockEnv ( )
278278
279279 @pytest .fixture
280280 def simple_policy (self , simple_env ):
@@ -291,7 +291,7 @@ def test_syncdatacollector_multiprocess_scheme(self, simple_policy):
291291 scheme = MultiProcessWeightSyncScheme (strategy = "state_dict" )
292292
293293 collector = SyncDataCollector (
294- create_env_fn = lambda : GymEnv ( "CartPole-v1" ) ,
294+ create_env_fn = ContinuousActionVecMockEnv ,
295295 policy = simple_policy ,
296296 frames_per_batch = 64 ,
297297 total_frames = 128 ,
@@ -316,8 +316,8 @@ def test_multisyncdatacollector_multiprocess_scheme(self, simple_policy):
316316
317317 collector = MultiSyncDataCollector (
318318 create_env_fn = [
319- lambda : GymEnv ( "CartPole-v1" ) ,
320- lambda : GymEnv ( "CartPole-v1" ) ,
319+ ContinuousActionVecMockEnv ,
320+ ContinuousActionVecMockEnv ,
321321 ],
322322 policy = simple_policy ,
323323 frames_per_batch = 64 ,
@@ -343,8 +343,8 @@ def test_multisyncdatacollector_shared_mem_scheme(self, simple_policy):
343343
344344 collector = MultiSyncDataCollector (
345345 create_env_fn = [
346- lambda : GymEnv ( "CartPole-v1" ) ,
347- lambda : GymEnv ( "CartPole-v1" ) ,
346+ ContinuousActionVecMockEnv ,
347+ ContinuousActionVecMockEnv ,
348348 ],
349349 policy = simple_policy ,
350350 frames_per_batch = 64 ,
@@ -369,7 +369,7 @@ def test_collector_no_weight_sync(self, simple_policy):
369369 scheme = NoWeightSyncScheme ()
370370
371371 collector = SyncDataCollector (
372- create_env_fn = lambda : GymEnv ( "CartPole-v1" ) ,
372+ create_env_fn = ContinuousActionVecMockEnv ,
373373 policy = simple_policy ,
374374 frames_per_batch = 64 ,
375375 total_frames = 128 ,
@@ -385,7 +385,7 @@ def test_collector_no_weight_sync(self, simple_policy):
385385
386386class TestMultiModelUpdates :
387387 def test_multi_model_state_dict_updates (self ):
388- env = GymEnv ( "CartPole-v1" )
388+ env = ContinuousActionVecMockEnv ( )
389389
390390 policy = TensorDictModule (
391391 nn .Linear (
@@ -407,7 +407,7 @@ def test_multi_model_state_dict_updates(self):
407407 }
408408
409409 collector = SyncDataCollector (
410- create_env_fn = lambda : GymEnv ( "CartPole-v1" ) ,
410+ create_env_fn = ContinuousActionVecMockEnv ,
411411 policy = policy ,
412412 frames_per_batch = 64 ,
413413 total_frames = 128 ,
@@ -438,7 +438,7 @@ def test_multi_model_state_dict_updates(self):
438438 env .close ()
439439
440440 def test_multi_model_tensordict_updates (self ):
441- env = GymEnv ( "CartPole-v1" )
441+ env = ContinuousActionVecMockEnv ( )
442442
443443 policy = TensorDictModule (
444444 nn .Linear (
@@ -460,7 +460,7 @@ def test_multi_model_tensordict_updates(self):
460460 }
461461
462462 collector = SyncDataCollector (
463- create_env_fn = lambda : GymEnv ( "CartPole-v1" ) ,
463+ create_env_fn = ContinuousActionVecMockEnv ,
464464 policy = policy ,
465465 frames_per_batch = 64 ,
466466 total_frames = 128 ,
0 commit comments