diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index ecc5e0222..4f739f047 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -176,7 +176,7 @@ def reward_metric(rews): def watch( args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None ) -> None: - env = get_env() + env = DummyVectorEnv([get_env]) policy.eval() [agent.set_eps(args.eps_test) for agent in policy.policies.values()] collector = Collector(policy, env, exploration_noise=True) diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index 33072b50f..da1285a40 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -268,7 +268,7 @@ def reward_metric(rews): def watch( args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None ) -> None: - env = get_env() + env = DummyVectorEnv([get_env]) policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render)