Commit 5d7095f

Update (base update)
[ghstack-poisoned]
2 parents fafd6a8 + 4f013a8 commit 5d7095f

File tree: 3 files changed, +45 -15 lines

test/test_env.py

Lines changed: 26 additions & 12 deletions
@@ -13,6 +13,7 @@
 import pickle
 import random
 import re
+import time
 from collections import defaultdict
 from functools import partial
 from sys import platform
@@ -3715,26 +3716,39 @@ def test_batched_nondynamic(self, penv):
             use_buffers=True,
             mp_start_method=mp_ctx if penv is ParallelEnv else None,
         )
-        env_buffers.set_seed(0)
-        torch.manual_seed(0)
-        rollout_buffers = env_buffers.rollout(
-            20, return_contiguous=True, break_when_any_done=False
-        )
-        del env_buffers
+        try:
+            env_buffers.set_seed(0)
+            torch.manual_seed(0)
+            rollout_buffers = env_buffers.rollout(
+                20, return_contiguous=True, break_when_any_done=False
+            )
+        finally:
+            env_buffers.close(raise_if_closed=False)
+            del env_buffers
         gc.collect()
+        # Add a small delay to allow multiprocessing resource_sharer threads
+        # to fully clean up before creating the next environment. This prevents
+        # a race condition where the old resource_sharer service thread is still
+        # active when the new environment starts, causing a deadlock.
+        # See: https://bugs.python.org/issue30289
+        if penv is ParallelEnv:
+            time.sleep(0.1)

         env_no_buffers = penv(
             3,
             lambda: GymEnv(CARTPOLE_VERSIONED(), device=None),
             use_buffers=False,
             mp_start_method=mp_ctx if penv is ParallelEnv else None,
         )
-        env_no_buffers.set_seed(0)
-        torch.manual_seed(0)
-        rollout_no_buffers = env_no_buffers.rollout(
-            20, return_contiguous=True, break_when_any_done=False
-        )
-        del env_no_buffers
+        try:
+            env_no_buffers.set_seed(0)
+            torch.manual_seed(0)
+            rollout_no_buffers = env_no_buffers.rollout(
+                20, return_contiguous=True, break_when_any_done=False
+            )
+        finally:
+            env_no_buffers.close(raise_if_closed=False)
+            del env_no_buffers
         gc.collect()
         assert_allclose_td(rollout_buffers, rollout_no_buffers)
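The cleanup added above follows a reusable pattern rather than anything test-specific: run the rollout inside try/finally, close the batched env without raising if its workers are already gone, and pause briefly before constructing the next ParallelEnv so the multiprocessing resource_sharer service thread can exit (see https://bugs.python.org/issue30289). Below is a minimal sketch of that pattern outside the test harness; it assumes a gym/gymnasium CartPole build is installed, and "CartPole-v1" merely stands in for the CARTPOLE_VERSIONED() helper.

import gc
import time

import torch
from torchrl.envs import GymEnv, ParallelEnv


def make_env():
    # "CartPole-v1" is illustrative; the test uses CARTPOLE_VERSIONED().
    return ParallelEnv(3, lambda: GymEnv("CartPole-v1"), use_buffers=True)


env = make_env()
try:
    env.set_seed(0)
    torch.manual_seed(0)
    rollout = env.rollout(20, return_contiguous=True, break_when_any_done=False)
finally:
    # Never raise if the workers already shut down on their own.
    env.close(raise_if_closed=False)
    del env
gc.collect()
# Let the resource_sharer service thread wind down before spawning new workers.
time.sleep(0.1)

env = make_env()  # a fresh ParallelEnv can now start without hitting the deadlock

The test applies the same guard to both the buffered and the non-buffered env, so a failing rollout can no longer leak worker processes into the environment created next.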

test/test_libs.py

Lines changed: 17 additions & 3 deletions
@@ -5,6 +5,7 @@
 from __future__ import annotations

 import collections
+import copy
 import functools
 import gc
 import importlib.util
@@ -2811,14 +2812,27 @@ def test_vmas_seeding(self, scenario_name):
         final_seed = []
         tdreset = []
         tdrollout = []
-        for _ in range(2):
-            env = VmasEnv(
+        rollout_length = 10
+
+        def create_env():
+            return VmasEnv(
                 scenario=scenario_name,
                 num_envs=4,
             )
+
+        env = create_env()
+        td_actions = [env.action_spec.rand() for _ in range(rollout_length)]
+
+        for _ in range(2):
+            env = create_env()
+            td_actions_buffer = copy.deepcopy(td_actions)
+
+            def policy(td, actions=td_actions_buffer):
+                return actions.pop(0)
+
             final_seed.append(env.set_seed(0))
             tdreset.append(env.reset())
-            tdrollout.append(env.rollout(max_steps=10))
+            tdrollout.append(env.rollout(max_steps=rollout_length, policy=policy))
             env.close()
             del env
         assert final_seed[0] == final_seed[1]
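The reworked VMAS seeding test separates action randomness from environment randomness: one action sequence is sampled up front, deep-copied for every run, and replayed by a small closure policy, so any difference between the two rollouts can only come from how the env itself is seeded. Here is a minimal sketch of the same trick against a plain GymEnv so that it does not require VMAS; the "CartPole-v1" id and the rollouts list are illustrative, and the pre-sampled action is written into the step tensordict rather than returned directly.

import copy

from torchrl.envs import GymEnv

rollout_length = 10


def create_env():
    return GymEnv("CartPole-v1")


# Sample the action sequence once, from a throwaway env.
env = create_env()
td_actions = [env.action_spec.rand() for _ in range(rollout_length)]
env.close()

rollouts = []
for _ in range(2):
    env = create_env()
    actions = copy.deepcopy(td_actions)  # each run consumes its own copy

    def policy(td, actions=actions):
        # Replay the pre-sampled action for this step.
        td.set("action", actions.pop(0))
        return td

    env.set_seed(0)
    rollouts.append(env.rollout(max_steps=rollout_length, policy=policy))
    env.close()
# With identical actions, rollouts[0] and rollouts[1] should match whenever seeding works.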

torchrl/data/datasets/d4rl.py

Lines changed: 2 additions & 0 deletions
@@ -279,6 +279,7 @@ def _get_dataset_direct(self, name, env_kwargs):
         # so we need to ensure we're using the gym backend
         with set_gym_backend("gym"):
             import gym
+
             env = GymWrapper(gym.make(name))
         with tempfile.TemporaryDirectory() as tmpdir:
             os.environ["D4RL_DATASET_DIR"] = tmpdir
@@ -358,6 +359,7 @@ def _get_dataset_from_env(self, name, env_kwargs):
         # so we need to ensure we're using the gym backend
         with set_gym_backend("gym"), tempfile.TemporaryDirectory() as tmpdir:
             import gym
+
             os.environ["D4RL_DATASET_DIR"] = tmpdir
             env = GymWrapper(gym.make(name))
             dataset = make_tensordict(
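For context, both hunks sit inside the pattern d4rl.py uses to build its environments: pin the classic "gym" backend with set_gym_backend and point D4RL_DATASET_DIR at a temporary directory so dataset downloads stay out of the user's home directory. A rough sketch of that pattern in isolation is shown below; it assumes d4rl and classic gym are installed, and "halfcheetah-medium-v2" is only an example task id.

import os
import tempfile

from torchrl.envs.libs.gym import GymWrapper, set_gym_backend

with tempfile.TemporaryDirectory() as tmpdir:
    # Set the download location before d4rl is imported so it is picked up.
    os.environ["D4RL_DATASET_DIR"] = tmpdir
    # Everything built inside this block uses the classic gym backend,
    # even if gymnasium is also installed.
    with set_gym_backend("gym"):
        import d4rl  # noqa: F401  (registers the offline-RL env ids with gym)
        import gym

        gym_env = gym.make("halfcheetah-medium-v2")  # example D4RL task
        dataset = gym_env.get_dataset()  # d4rl's accessor; downloads into tmpdir
        env = GymWrapper(gym_env)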
