can evaluate and save best model with callbacks #117

Merged 1 commit on Jun 21, 2023
13 changes: 12 additions & 1 deletion examples/cartpole/callbacks.yaml
@@ -2,7 +2,18 @@ callbacks:
- id: "CheckpointCallback"
args: {
"save_freq": 500,
"save_path": "./checkpoints/",
"save_path": "./results/checkpoints/",
"name_prefix": "ppo",
"save_replay_buffer": True
}
- id: "EvalCallback"
args: {
"eval_env": {"id": "CartPole-v1","env_num":4},
"n_eval_episodes": 4,
"eval_freq":500,
"log_path": "./results/eval_log_path",
"best_model_save_path": "./results/best_model/",
"deterministic": True,
"render": True,
"asynchronous": True,
}
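For reference, a minimal sketch of how a callback list like the one above could be turned into callback objects. This is illustrative only: the `CALLBACK_REGISTRY` mapping and `load_callbacks` helper below are hypothetical stand-ins, not OpenRL's actual config loader.

```python
# Hypothetical sketch, not part of this PR: map each "id" entry in
# callbacks.yaml to a constructor and instantiate it with its "args".
import yaml

CALLBACK_REGISTRY = {
    # In OpenRL these ids resolve to real callback classes; here we just
    # echo the parsed configuration so the sketch stays self-contained.
    "CheckpointCallback": lambda **kwargs: ("CheckpointCallback", kwargs),
    "EvalCallback": lambda **kwargs: ("EvalCallback", kwargs),
}

def load_callbacks(path: str):
    with open(path) as f:
        cfg = yaml.safe_load(f)
    callbacks = []
    for entry in cfg["callbacks"]:
        build = CALLBACK_REGISTRY[entry["id"]]
        callbacks.append(build(**entry["args"]))
    return callbacks

# e.g. load_callbacks("examples/cartpole/callbacks.yaml")
```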
2 changes: 1 addition & 1 deletion examples/cartpole/train_ppo.py
@@ -49,4 +49,4 @@ def evaluation(agent):

if __name__ == "__main__":
agent = train()
# evaluation(agent)
evaluation(agent)
1 change: 1 addition & 0 deletions openrl/envs/common/build_envs.py
@@ -40,6 +40,7 @@ def _make_env() -> Env:
env = wrapper(env)
else:
raise NotImplementedError

return env

return _make_env
4 changes: 2 additions & 2 deletions openrl/envs/common/registration.py
@@ -18,11 +18,11 @@
from typing import Callable, Optional

import gymnasium as gym
from gymnasium import Env

import openrl
from openrl.envs.vec_env import (
AsyncVectorEnv,
BaseVecEnv,
RewardWrapper,
SyncVectorEnv,
VecMonitorWrapper,
@@ -40,7 +40,7 @@ def make(
render_mode: Optional[str] = None,
make_custom_envs: Optional[Callable] = None,
**kwargs,
) -> Env:
) -> BaseVecEnv:
if render_mode in [None, "human", "rgb_array"]:
convert_render_mode = render_mode
elif render_mode in ["group_human", "group_rgb_array"]:
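The annotation change reflects that `make` builds a vectorized environment. A usage sketch, assuming `make` accepts the same `env_num` keyword used in the `eval_env` entry of callbacks.yaml:

```python
# Sketch: build a vectorized CartPole env; after this PR the annotated
# return type is BaseVecEnv rather than gymnasium.Env.
from openrl.envs.common import make

env = make("CartPole-v1", env_num=4)
obs, info = env.reset()  # assumes the gymnasium-style (obs, info) reset API
env.close()
```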
7 changes: 1 addition & 6 deletions openrl/envs/connect3/__init__.py
@@ -30,10 +30,7 @@ def make_connect3_envs(
render_mode: Optional[Union[str, List[str]]] = None,
**kwargs,
) -> List[Callable[[], Env]]:
from openrl.envs.wrappers import (
RemoveTruncated,
Single2MultiAgentWrapper,
)
from openrl.envs.wrappers import RemoveTruncated, Single2MultiAgentWrapper

env_wrappers = [
Single2MultiAgentWrapper,
@@ -48,5 +45,3 @@ def make_connect3_envs(
**kwargs,
)
return env_fns


87 changes: 49 additions & 38 deletions openrl/envs/connect3/connect3_env.py
@@ -1,7 +1,7 @@
from typing import Any, Optional

import numpy as np
import gymnasium as gym
import numpy as np
from gymnasium import Env, spaces


@@ -12,24 +12,31 @@ def make(
) -> Env:
# create Connect3 environment from id
if id == "connect3":
env = Connect3Env(env_name=id,
args=kwargs)
env = Connect3Env(env_name=id, args=kwargs)

return env


def check_if_win(state, check_row_pos, check_col_pos, all_args):
def check_if_win_direction(now_state, direction, row_pos, col_pos, args):
def check_if_valid(x_pos, y_pos):
return x_pos >= 0 and x_pos <= (args["row"] - 1) and y_pos >= 0 and y_pos <= (args["col"] - 1)
return (
x_pos >= 0
and x_pos <= (args["row"] - 1)
and y_pos >= 0
and y_pos <= (args["col"] - 1)
)

check_who = now_state[row_pos][col_pos]
counting = 1
bias_num = 1
while True:
new_row_pos = row_pos + bias_num * direction[0]
new_col_pos = col_pos + bias_num * direction[1]
if not check_if_valid(new_row_pos, new_col_pos) or now_state[new_row_pos][new_col_pos] != check_who:
if (
not check_if_valid(new_row_pos, new_col_pos)
or now_state[new_row_pos][new_col_pos] != check_who
):
break
else:
counting += 1
@@ -38,7 +45,10 @@ def check_if_valid(x_pos, y_pos):
while True:
new_row_pos = row_pos + bias_num * direction[0]
new_col_pos = col_pos + bias_num * direction[1]
if not check_if_valid(new_row_pos, new_col_pos) or now_state[new_row_pos][new_col_pos] != check_who:
if (
not check_if_valid(new_row_pos, new_col_pos)
or now_state[new_row_pos][new_col_pos] != check_who
):
break
else:
counting += 1
@@ -50,7 +60,9 @@ def check_if_valid(x_pos, y_pos):

directions = [(1, 0), (0, 1), (1, 1), (1, -1)]  # horizontal, vertical, down-right diagonal, up-right diagonal
for direction in directions:
if check_if_win_direction(state, direction, check_row_pos, check_col_pos, all_args):
if check_if_win_direction(
state, direction, check_row_pos, check_col_pos, all_args
):
return True
return False

@@ -69,36 +81,40 @@ def __init__(self, env_name, args):

obs_space_low = np.zeros(obs_space_dim) - 1e6
obs_space_high = np.zeros(obs_space_dim) + 1e6
obs_space_type = 'float64'
obs_space_type = "float64"
sobs_space_dim = obs_space_dim * args["num_agents"]
sobs_space_low = np.zeros(sobs_space_dim) - 1e6
sobs_space_high = np.zeros(sobs_space_dim) + 1e6

if args["num_agents"] > 1:
self.action_space = [spaces.Discrete(self.action_size) for _ in range(args["num_agents"])]
self.observation_space = [spaces.Box(
low=obs_space_low,
high=obs_space_high,
dtype=obs_space_type) for _ in range(args["num_agents"])]
self.share_observation_space = [spaces.Box(
low=sobs_space_low,
high=sobs_space_high,
dtype=obs_space_type) for _ in range(args["num_agents"])]
self.action_space = [
spaces.Discrete(self.action_size) for _ in range(args["num_agents"])
]
self.observation_space = [
spaces.Box(low=obs_space_low, high=obs_space_high, dtype=obs_space_type)
for _ in range(args["num_agents"])
]
self.share_observation_space = [
spaces.Box(
low=sobs_space_low, high=sobs_space_high, dtype=obs_space_type
)
for _ in range(args["num_agents"])
]
else:
self.action_space = spaces.Discrete(self.action_size)
self.observation_space = spaces.Box(
low=obs_space_low,
high=obs_space_high,
dtype=obs_space_type)
low=obs_space_low, high=obs_space_high, dtype=obs_space_type
)
self.share_observation_space = spaces.Box(
low=sobs_space_low,
high=sobs_space_high,
dtype=obs_space_type)
low=sobs_space_low, high=sobs_space_high, dtype=obs_space_type
)

def step(self, action, is_enemy=True):
# the incoming action is an integer in the range 0-8
row_pos, col_pos = action // self.col, action % self.col
assert self.state[row_pos][col_pos] == 0, "({}, {}) pos has already be taken".format(row_pos, col_pos)
assert (
self.state[row_pos][col_pos] == 0
), "({}, {}) pos has already be taken".format(row_pos, col_pos)
self.state[row_pos][col_pos] = 2 if is_enemy else 1
done, have_winner = False, False

@@ -110,28 +126,28 @@ def step(self, action, is_enemy=True):
if done:
if have_winner:
reward = (-1) * self.reward if is_enemy else self.reward
winner = 'enemy' if is_enemy else 'self'
winner = "enemy" if is_enemy else "self"
else:
winner = 'tie'
winner = "tie"
reward = 0
else:
reward = 0
winner = 'no'
info = {'who_win': winner}
winner = "no"
info = {"who_win": winner}
return self.state.flatten().copy(), reward, done, False, info

def check_if_finish(self):
return (self.state == 0).sum() == 0

def reset(self, seed=None, options=None, set_who_first=None):
def reset(self, seed=None, options=None, set_who_first=None):
self.state = np.zeros([self.row, self.col])  # 0: empty, 1: own piece, 2: enemy piece
if set_who_first is not None:
who_first = set_who_first
else:
if np.random.random() > 0.5:
who_first = 'enemy'
who_first = "enemy"
else:
who_first = 'self'
who_first = "self"
obs = self.state.flatten().copy()
# return obs, {"who_first": who_first}
return obs, {}
@@ -146,13 +162,8 @@ def close(self):
pass


if __name__ == '__main__':
args = {
"row": 3,
"col": 3,
"num_to_win": 3,
"num_agents": 1
}
if __name__ == "__main__":
args = {"row": 3, "col": 3, "num_to_win": 3, "num_agents": 1}
env = Connect3Env(env_name="connect3", args=args)
obs, info = env.reset()
obs, reward, done, _, info = env.step(1, is_enemy=True)
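A small usage sketch for the win check above (assuming, as the surrounding code suggests, that `check_if_win_direction` compares the run length against `args["num_to_win"]`):

```python
# Sketch: check_if_win scans the four directions around the last move.
import numpy as np
from openrl.envs.connect3.connect3_env import check_if_win

args = {"row": 3, "col": 3, "num_to_win": 3, "num_agents": 1}
state = np.zeros([3, 3])
state[0, :] = 1                          # three of player 1's pieces in the top row
print(check_if_win(state, 0, 2, args))   # True: the move at (0, 2) completes a row
state[0, 1] = 2                          # break the run
print(check_if_win(state, 0, 2, args))   # False
```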
82 changes: 75 additions & 7 deletions openrl/envs/vec_env/async_venv.py
@@ -6,7 +6,7 @@
from enum import Enum
from multiprocessing import Queue
from multiprocessing.connection import Connection
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union

import gymnasium as gym
import numpy as np
@@ -21,7 +21,7 @@
from gymnasium.vector.utils import CloudpickleWrapper, clear_mpi_env_vars
from numpy.typing import NDArray

from openrl.envs.vec_env.base_venv import BaseVecEnv
from openrl.envs.vec_env.base_venv import BaseVecEnv, VecEnvIndices
from openrl.envs.vec_env.utils.numpy_utils import (
concatenate,
create_empty_array,
@@ -32,6 +32,8 @@
read_from_shared_memory,
write_to_shared_memory,
)
from openrl.envs.wrappers.base_wrapper import BaseWrapper
from openrl.envs.wrappers.util import is_wrapped


class AsyncState(Enum):
@@ -608,19 +610,65 @@ def call_fetch(self, timeout: Union[int, float, None] = None) -> list:

return results

def call(self, name: str, *args, **kwargs) -> List[Any]:
"""Call a method, or get a property, from each parallel environment.
def exec_func_send(self, func: Callable, indices, *args, **kwargs):
"""Calls the method with name asynchronously and apply args and kwargs to the method.

Args:
name (str): Name of the method or property to call.
func: The function to execute; each selected worker calls it as `func(env, *args, **kwargs)`.
indices: Indices of the environments to run the function on.
*args: Arguments to apply to the method call.
**kwargs: Keyword arguments to apply to the method call.

Raises:
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
AlreadyPendingCallError: Calling `exec_func_send` while waiting for a pending call to complete.
"""
self._assert_is_running()
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
(
"Calling `exec_func_send` while waiting "
f"for a pending call to `{self._state.value}` to complete."
),
str(self._state.value),
)

for pipe in self.parent_pipes:
pipe.send(("_func_exec", (func, indices, args, kwargs)))
self._state = AsyncState.WAITING_CALL

def exec_func_fetch(self, timeout: Union[int, float, None] = None) -> list:
"""Calls all parent pipes and waits for the results.

Args:
timeout: Number of seconds before the call to `exec_func_fetch` times out.
If `None` (default), the call to `exec_func_fetch` never times out.

Returns:
List of the results of the individual calls to the method or property for each environment.

Raises:
NoAsyncCallError: Calling `exec_func_fetch` without any prior call to `exec_func_send`.
TimeoutError: The call to `exec_func_fetch` has timed out after timeout second(s).
"""
self.call_send(name, *args, **kwargs)
return self.call_fetch()
self._assert_is_running()
if self._state != AsyncState.WAITING_CALL:
raise NoAsyncCallError(
"Calling `exec_func_fetch` without any prior call to `exec_func_send`.",
AsyncState.WAITING_CALL.value,
)

if not self._poll(timeout):
self._state = AsyncState.DEFAULT
raise mp.TimeoutError(
f"The call to `call_fetch` has timed out after {timeout} second(s)."
)

results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
self._raise_if_errors(successes)
self._state = AsyncState.DEFAULT

return results

def get_attr(self, name: str):
"""Get a property from each parallel environment.
@@ -670,6 +718,16 @@ def set_attr(self, name: str, values: Union[List[Any], Tuple[Any], object]):
_, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
self._raise_if_errors(successes)

# def env_is_wrapped(
# self, wrapper_class: Type[BaseWrapper], indices: VecEnvIndices = None
# ) -> List[bool]:
# """Check if worker environments are wrapped with a given wrapper"""
# indices = self._get_indices(indices)
# results = self.exec_func(
# is_wrapped, indices=indices, wrapper_class=wrapper_class
# )
# return [results[i] for i in indices]


def _worker(
index: int,
@@ -771,6 +829,15 @@ def prepare_obs(observation):
True,
)
)
elif command == "_func_exec":
function, indices, args, kwargs = data
if index in indices:
if callable(function):
pipe.send((function(env, *args, **kwargs), True))
else:
pipe.send((function, True))
else:
pipe.send((None, True))
elif command == "_call":
name, args, kwargs = data
if name in ["reset", "step", "seed", "close"]:
@@ -783,6 +850,7 @@ def prepare_obs(observation):
pipe.send((function(*args, **kwargs), True))
else:
pipe.send((function, True))

elif command == "_setattr":
name, value = data
setattr(env, name, value)
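Finally, a usage sketch for the new `exec_func_send` / `exec_func_fetch` pair (hypothetical, not taken from the PR). The function is pickled and sent to each worker, which calls it as `func(env, *args, **kwargs)` when its index is selected, so it must be defined at module level (or otherwise picklable). The `asynchronous=True` keyword is assumed to build an `AsyncVectorEnv`, as the `eval_env` entry in callbacks.yaml suggests.

```python
# Hypothetical usage of the new exec_func_* API.
from openrl.envs.common import make

def env_name(env):
    # Runs inside the worker process and receives the raw environment.
    return type(env).__name__

if __name__ == "__main__":
    venv = make("CartPole-v1", env_num=4, asynchronous=True)
    venv.exec_func_send(env_name, indices=[0, 2])
    results = venv.exec_func_fetch(timeout=10)
    print(results)  # one entry per worker; None for unselected indices
    venv.close()
```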