Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

demo(nyz): slime volleyball league training #229

Merged
merged 7 commits into from
Mar 19, 2022
1 change: 0 additions & 1 deletion dizoo/slime_volley/config/__init__.py

This file was deleted.

57 changes: 23 additions & 34 deletions dizoo/slime_volley/envs/slime_volley_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ def close(self) -> None:
self._env.close()
self._init_flag = False

def step(self, action: Union[np.ndarray, List[np.ndarray]]):
def step(self, action: Union[np.ndarray, List[np.ndarray]]) -> BaseEnvTimestep:
if self._agent_vs_agent:
assert isinstance(action, list) and isinstance(action[0], np.ndarray)
assert isinstance(action, List) and all([isinstance(e, np.ndarray) for e in action])
action1, action2 = action[0], action[1]
else:
assert isinstance(action, np.ndarray)
Expand Down Expand Up @@ -101,6 +101,9 @@ def reset(self):
self._env = GymSelfPlayMonitor(
self._env, self._replay_path, video_callable=lambda episode_id: True, force=True
)
self._observation_space = self._env.observation_space
self._action_space = gym.spaces.Discrete(6)
self._reward_space = gym.spaces.Box(low=-5, high=5, shape=(1, ), dtype=np.float32)
self._init_flag = True
if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
np_seed = 100 * np.random.randint(1, 1000)
Expand All @@ -116,38 +119,24 @@ def reset(self):
else:
return obs

def info(self):
    """Legacy DI-engine env-info descriptor for Slime Volley.

    Describes the observation/action/reward spaces of the env:
    a 12-dim observation per agent (stacked to ``(2, 12)`` when
    ``self._agent_vs_agent`` is set), 6 valid discrete actions in
    ``[0, 6)``, and a scalar reward bounded in ``[-5, 5]``.
    """
    space_info = EnvElementInfo
    obs_shape = (2, 12) if self._agent_vs_agent else (12, )
    inf = float("inf")
    return BaseEnvInfo(
        agent_num=2,
        obs_space=space_info(
            obs_shape,
            {
                'min': [-inf] * 12,
                'max': [inf] * 12,
                'dtype': np.float32,
            },
        ),
        # action ids lie in [min, max): 6 valid discrete actions
        act_space=space_info(
            (1, ),
            {
                'min': 0,
                'max': 6,
                'dtype': int,
            },
        ),
        rew_space=space_info(
            (1, ),
            {
                'min': -5.0,
                'max': 5.0,
                'dtype': np.float32,
            },
        ),
        use_wrappers=None,
    )
@property
def observation_space(self) -> gym.spaces.Space:
    """Observation space taken from the wrapped gym env (populated in ``reset``)."""
    return self._observation_space

@property
def action_space(self) -> gym.spaces.Space:
    """Discrete(6) action space set up when the env is created in ``reset``."""
    return self._action_space

@property
def reward_space(self) -> gym.spaces.Space:
    """Box reward space in [-5, 5], set up when the env is created in ``reset``."""
    return self._reward_space

def random_action(self) -> np.ndarray:
    """Sample uniformly random action(s) from the discrete action space.

    Returns a single ``(1,)`` int array in single-agent mode, or a list of
    two such arrays (one per player) when ``self._agent_vs_agent`` is set.
    """
    n_actions = self.action_space.n

    def _sample() -> np.ndarray:
        # one draw from [0, n_actions), shaped (1,) to match the env API
        return np.random.randint(0, n_actions, size=(1, ))

    if self._agent_vs_agent:
        return [_sample() for _ in range(2)]
    return _sample()

def __repr__(self):
return "DI-engine Slime Volley Env"
Expand Down
13 changes: 5 additions & 8 deletions dizoo/slime_volley/envs/test_slime_volley_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,16 @@ def test_slime_volley(self, agent_vs_agent):
env = SlimeVolleyEnv(EasyDict({'env_id': 'SlimeVolley-v0', 'agent_vs_agent': agent_vs_agent}))
# env.enable_save_replay('replay_video')
obs1 = env.reset()
done = False
print(env._env.observation_space)
print(env.observation_space)
print('observation is like:', obs1)
done = False
while not done:
action = env.random_action()
observations, rewards, done, infos = env.step(action)
if agent_vs_agent:
action1 = np.random.randint(0, 2, (1, ))
action2 = np.random.randint(0, 2, (1, ))
action = [action1, action2]
total_rew += rewards[0]
else:
action = np.random.randint(0, 2, (1, ))
observations, rewards, done, infos = env.step(action)
total_rew += rewards[0]
total_rew += rewards
obs1, obs2 = observations[0], observations[1]
assert obs1.shape == obs2.shape, (obs1.shape, obs2.shape)
if agent_vs_agent:
Expand Down