From e93b5a6eb1bede0be161feace9276e6a8d5810ac Mon Sep 17 00:00:00 2001
From: PaParaZz1
Date: Tue, 3 Dec 2024 22:57:16 +0800
Subject: [PATCH] fix(nyz): fix middleware collector env reset bug (#845)

---
 .../middleware/functional/collector.py |  2 ++
 .../middleware/tests/mock_for_test.py  | 18 +++++++++++++++---
 .../middleware/tests/test_collector.py | 13 ++++++++-----
 3 files changed, 25 insertions(+), 8 deletions(-)

diff --git a/ding/framework/middleware/functional/collector.py b/ding/framework/middleware/functional/collector.py
index d2fb4483b9..d87d1e0997 100644
--- a/ding/framework/middleware/functional/collector.py
+++ b/ding/framework/middleware/functional/collector.py
@@ -159,6 +159,8 @@ def _rollout(ctx: "OnlineRLContext"):
                     'step': env_info[timestep.env_id.item()]['step'],
                     'train_sample': env_info[timestep.env_id.item()]['train_sample'],
                 }
+                # reset the corresponding env info
+                env_info[timestep.env_id.item()] = {'time': 0., 'step': 0, 'train_sample': 0}
                 episode_info.append(info)
                 policy.reset([timestep.env_id.item()])
diff --git a/ding/framework/middleware/tests/mock_for_test.py b/ding/framework/middleware/tests/mock_for_test.py
index 0ad88909a5..986c39e128 100644
--- a/ding/framework/middleware/tests/mock_for_test.py
+++ b/ding/framework/middleware/tests/mock_for_test.py
@@ -1,5 +1,6 @@
 from typing import Union, Any, List, Callable, Dict, Optional
 from collections import namedtuple
+import random
 import torch
 import treetensor.numpy as tnp
 from easydict import EasyDict
@@ -75,6 +76,7 @@ def __init__(self) -> None:
         self.obs_dim = obs_dim
         self.closed = False
         self._reward_grow_indicator = 1
+        self._steps = [0 for _ in range(self.env_num)]

     @property
     def ready_obs(self) -> tnp.array:
@@ -90,16 +92,26 @@ def launch(self, reset_param: Optional[Dict] = None) -> None:
         return

     def reset(self, reset_param: Optional[Dict] = None) -> None:
-        return
+        self._steps = [0 for _ in range(self.env_num)]

     def step(self, actions: tnp.ndarray) -> List[tnp.ndarray]:
         timesteps = []
         for i in range(self.env_num):
+            if self._steps[i] < 5:
+                done = False
+            elif self._steps[i] < 10:
+                done = random.random() > 0.5
+            else:
+                done = True
+            if done:
+                self._steps[i] = 0
+            else:
+                self._steps[i] += 1
             timestep = dict(
                 obs=torch.rand(self.obs_dim),
                 reward=1.0,
-                done=True,
-                info={'eval_episode_return': self._reward_grow_indicator * 1.0},
+                done=done,
+                info={'eval_episode_return': self._reward_grow_indicator * 1.0} if done else {},
                 env_id=i,
             )
             timesteps.append(tnp.array(timestep))
diff --git a/ding/framework/middleware/tests/test_collector.py b/ding/framework/middleware/tests/test_collector.py
index 13d45c3c3d..40dc6dae7f 100644
--- a/ding/framework/middleware/tests/test_collector.py
+++ b/ding/framework/middleware/tests/test_collector.py
@@ -22,16 +22,19 @@ def test_inferencer():

 @pytest.mark.unittest
 def test_rolloutor():
+    N = 20
     ctx = OnlineRLContext()
     transitions = TransitionList(2)
     with patch("ding.policy.Policy", MockPolicy), patch("ding.envs.BaseEnvManagerV2", MockEnv):
         policy = MockPolicy()
         env = MockEnv()
-        for _ in range(10):
-            inferencer(0, policy, env)(ctx)
-            rolloutor(policy, env, transitions)(ctx)
-        assert ctx.env_episode == 20  # 10 * env_num
-        assert ctx.env_step == 20  # 10 * env_num
+        i = inferencer(0, policy, env)
+        r = rolloutor(policy, env, transitions)
+        for _ in range(N):
+            i(ctx)
+            r(ctx)
+        assert ctx.env_step == N * 2  # N * env_num
+        assert ctx.env_episode >= N // 10 * 2  # lower bound: episodes end within ~10 steps, env_num == 2


 @pytest.mark.unittest
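
Note: the pre-fix test never exercised this path, because inferencer(...) and
rolloutor(...) were re-created on every loop iteration (discarding any per-env
bookkeeping held in their closures) and the old MockEnv ended every episode
after a single step. The sketch below is a minimal, standalone illustration of
the bookkeeping bug the collector change fixes; the helper names are
hypothetical and only the env_info layout mirrors the patch:

    # Per-env stats, laid out as in rolloutor's env_info.
    env_info = {0: {'time': 0., 'step': 0, 'train_sample': 0}}

    def on_env_step(env_id: int) -> None:
        # Accumulate bookkeeping for one env step.
        env_info[env_id]['step'] += 1
        env_info[env_id]['train_sample'] += 1

    def on_episode_done(env_id: int) -> dict:
        # Snapshot the finished episode's stats ...
        info = dict(env_info[env_id])
        # ... and zero the counters, mirroring the two lines added in
        # collector.py. Without this reset, the next episode's summary
        # inherits the previous episode's totals.
        env_info[env_id] = {'time': 0., 'step': 0, 'train_sample': 0}
        return info

    for _ in range(5):
        on_env_step(0)
    first = on_episode_done(0)
    for _ in range(5):
        on_env_step(0)
    second = on_episode_done(0)
    assert first['step'] == second['step'] == 5  # without the reset, second['step'] would be 10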
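
For reference, the new MockEnv schedule yields episodes lasting between 6 and
11 env.step calls: five guaranteed-alive steps, a coin flip while _steps is
5-9, and a forced stop once _steps reaches 10. Episode counts are therefore
random, which is why the updated test only lower-bounds ctx.env_episode. A
quick standalone check of that schedule (episode_length is a hypothetical
re-implementation of MockEnv.step's done logic, not DI-engine code):

    import random

    def episode_length(seed: int) -> int:
        # Replay the mock env's done schedule for a single env.
        rng = random.Random(seed)
        steps = 0  # mirrors self._steps[i]
        while True:
            if steps < 5:
                done = False
            elif steps < 10:
                done = rng.random() > 0.5
            else:
                done = True
            if done:
                return steps + 1  # this call ended the episode
            steps += 1

    lengths = [episode_length(s) for s in range(1000)]
    assert 6 <= min(lengths) and max(lengths) <= 11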