Skip to content

Commit 4f013a8

Browse files
committed
[Test] Fix flaky parallel test
ghstack-source-id: aecea30 Pull-Request: #3204
1 parent e954e97 commit 4f013a8

File tree

2 files changed

+28
-12
lines changed

2 files changed

+28
-12
lines changed

test/test_env.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import pickle
1414
import random
1515
import re
16+
import time
1617
from collections import defaultdict
1718
from functools import partial
1819
from sys import platform
@@ -3715,26 +3716,39 @@ def test_batched_nondynamic(self, penv):
37153716
use_buffers=True,
37163717
mp_start_method=mp_ctx if penv is ParallelEnv else None,
37173718
)
3718-
env_buffers.set_seed(0)
3719-
torch.manual_seed(0)
3720-
rollout_buffers = env_buffers.rollout(
3721-
20, return_contiguous=True, break_when_any_done=False
3722-
)
3723-
del env_buffers
3719+
try:
3720+
env_buffers.set_seed(0)
3721+
torch.manual_seed(0)
3722+
rollout_buffers = env_buffers.rollout(
3723+
20, return_contiguous=True, break_when_any_done=False
3724+
)
3725+
finally:
3726+
env_buffers.close(raise_if_closed=False)
3727+
del env_buffers
37243728
gc.collect()
3729+
# Add a small delay to allow multiprocessing resource_sharer threads
3730+
# to fully clean up before creating the next environment. This prevents
3731+
# a race condition where the old resource_sharer service thread is still
3732+
# active when the new environment starts, causing a deadlock.
3733+
# See: https://bugs.python.org/issue30289
3734+
if penv is ParallelEnv:
3735+
time.sleep(0.1)
37253736

37263737
env_no_buffers = penv(
37273738
3,
37283739
lambda: GymEnv(CARTPOLE_VERSIONED(), device=None),
37293740
use_buffers=False,
37303741
mp_start_method=mp_ctx if penv is ParallelEnv else None,
37313742
)
3732-
env_no_buffers.set_seed(0)
3733-
torch.manual_seed(0)
3734-
rollout_no_buffers = env_no_buffers.rollout(
3735-
20, return_contiguous=True, break_when_any_done=False
3736-
)
3737-
del env_no_buffers
3743+
try:
3744+
env_no_buffers.set_seed(0)
3745+
torch.manual_seed(0)
3746+
rollout_no_buffers = env_no_buffers.rollout(
3747+
20, return_contiguous=True, break_when_any_done=False
3748+
)
3749+
finally:
3750+
env_no_buffers.close(raise_if_closed=False)
3751+
del env_no_buffers
37383752
gc.collect()
37393753
assert_allclose_td(rollout_buffers, rollout_no_buffers)
37403754

torchrl/data/datasets/d4rl.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,7 @@ def _get_dataset_direct(self, name, env_kwargs):
279279
# so we need to ensure we're using the gym backend
280280
with set_gym_backend("gym"):
281281
import gym
282+
282283
env = GymWrapper(gym.make(name))
283284
with tempfile.TemporaryDirectory() as tmpdir:
284285
os.environ["D4RL_DATASET_DIR"] = tmpdir
@@ -358,6 +359,7 @@ def _get_dataset_from_env(self, name, env_kwargs):
358359
# so we need to ensure we're using the gym backend
359360
with set_gym_backend("gym"), tempfile.TemporaryDirectory() as tmpdir:
360361
import gym
362+
361363
os.environ["D4RL_DATASET_DIR"] = tmpdir
362364
env = GymWrapper(gym.make(name))
363365
dataset = make_tensordict(

0 commit comments

Comments
 (0)