Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/update diambra #111

Merged
merged 3 commits into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions howto/learn_in_diambra.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ Each environment has its own observation and action space, so it is reccomended
>
> You have to be [registered](https://diambra.ai/register/) and logged in to acces the [DIAMRA documentation](https://docs.diambra.ai/).

The observation space is slightly modified to be compatible with our algorithms, in particular, the `gym.spaces.Box` observations are converted in `gymnasium.spaces.Box` observations, mantaining the dimensions, the range and the type of the observations. Moreover, the `gym.spaces.Discrete` observations are converted into `gymnasium.spaces.Box` observations with dimension `(1,)`, of type `int` and range from `0` to `n - 1`, where `n` is the number of options of the Discrete space. Finally, the `gym.spaces.MultiDiscrete` observations are converted into `gymnasium.spaces.Box` observations with dimension `(k,)` where `k` is the length of the MultiDiscrete space, of type `int` and range from `0` to `n[i] - 1` where `n[i]` is the number of options of the *i-th* element of the MultiDiscrete.
The observation space is slightly modified to be compatible with our algorithms, in particular, the `gymnasium.spaces.Discrete` observations are converted into `gymnasium.spaces.Box` observations with dimension `(1,)`, of type `int` and range from `0` to `n - 1`, where `n` is the number of options of the Discrete space. Finally, the `gymnasium.spaces.MultiDiscrete` observations are converted into `gymnasium.spaces.Box` observations with dimension `(k,)` where `k` is the length of the MultiDiscrete space, of type `int` and range from `0` to `n[i] - 1` where `n[i]` is the number of options of the *i-th* element of the MultiDiscrete.

> **Note**
>
Expand Down Expand Up @@ -86,7 +86,7 @@ env:
diambra_settings:
characters: Kasumi
step_ratio: 5
player: P1
role: diambra.arena.Roles.P1
diambra_wrappers:
reward_normalization: True
reward_normalization_factor: 0.3
Expand All @@ -102,14 +102,14 @@ diambra run -s=4 python sheeprl.py exp=custom_exp env.num_envs=4
> Some settings and wrappers are included in the cli arguments when the command is launched. These settings/wrappers cannot be specified in the `diambra_settings` and `diambra_wrappers` parameters, respectively.
> The settings/wrappers you cannot specify in the `diambra_settings` and `diambra_wrappers` parameters are the following:
> * `action_space` (settings): you can set it with the `env.wrapper.action_space` argument.
> * `attack_but_combination` (settings): you can set it with the `env.wrapper.attack_but_combination` argument.
> * `frame_shape` (settings): you can set it with the `env.screen_size` argument.
> * `n_players` (settings): you cannot set it, since it is always `1`.
> * `frame_shape` (settings and wrappers): you can set it with the `env.screen_size` argument.
> * `flatten` (wrappers): you cannot set it, since it is always `True`.
> * `sticky_actions` (wrappers): you can set it with the `env.action_repeat` argument.
> * `frame_stack` (wrappers): you can set it with the `env.frame_stack` argument.
> * `repeat_action` (wrappers): you can set it with the `env.action_repeat` argument.
> * `stack_frames` (wrappers): you can set it with the `env.stack_frames` argument.
> * `dilation` (wrappers): you can set it with the `env.frame_stack_dilation` argument
>
> When you set the `action_repeat` cli argument greater than one (i.e., the `sticky_actions` DIAMBRA wrapper), the `step_ratio` diambra setting is automatically modified to $1$ because it is a DIAMBRA requirement.
> When you set the `action_repeat` cli argument greater than one (i.e., the `repeat_action` DIAMBRA wrapper), the `step_ratio` diambra setting is automatically modified to $1$ because it is a DIAMBRA requirement.
>
> **Important**
>
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ atari = [
]
minedojo = ["minedojo==0.1", "importlib_resources==5.12.0"]
minerl = ["minerl==0.4.4"]
diambra = ["wheel==0.38.4", "setuptools<=66.0.0", "gym==0.21.0", "diambra==0.0.16", "diambra-arena==2.1.2"]
diambra = ["diambra==0.0.16", "diambra-arena==2.2.1"]
crafter = ["crafter==1.8.1"]

[tool.ruff]
Expand Down
24 changes: 17 additions & 7 deletions sheeprl/configs/env/diambra.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,25 @@ action_repeat: 1
wrapper:
_target_: sheeprl.envs.diambra.DiambraWrapper
id: ${env.id}
action_space: discrete
action_space: diambra.arena.SpaceTypes.DISCRETE # or diambra.arena.SpaceTypes.MULTI_DISCRETE
screen_size: ${env.screen_size}
grayscale: ${env.grayscale}
attack_but_combination: False
sticky_actions: ${env.action_repeat}
seed: null
repeat_action: ${env.action_repeat}
rank: null
log_level: 0
diambra_settings:
player: P1
role: diambra.arena.Roles.P1
step_ratio: 6
difficulty: 4
continue_game: 0.0
show_final: False
outfits: 1
diambra_wrappers:
actions_stack: 12
noop_max: 0
stack_actions: 1
no_op_max: 0
no_attack_buttons_combinations: False
add_last_action: True
scale: False
exclude_image_scaling: False
process_discrete_binary: False
role_relative: True
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ env:
screen_size: 128
reward_as_observation: True
wrapper:
attack_but_combination: True
diambra_settings:
characters: Kasumi
role: null
characters: null
difficulty: 4
diambra_wrappers:
no_attack_buttons_combinations: False

# Checkpoint
checkpoint:
Expand All @@ -38,30 +40,30 @@ cnn_keys:
- frame
mlp_keys:
encoder:
- reward
- P1_actions_attack
- P1_actions_move
- P1_oppChar
- P1_oppHealth
- P1_oppSide
- P1_oppWins
- P1_ownChar
- P1_ownHealth
- P1_ownSide
- P1_ownWins
- own_character
- own_health
- own_side
- own_wins
- opp_character
- opp_health
- opp_side
- opp_wins
- stage
- timer
- action
- reward
decoder:
- P1_actions_attack
- P1_actions_move
- P1_oppChar
- P1_oppHealth
- P1_oppSide
- P1_oppWins
- P1_ownChar
- P1_ownHealth
- P1_ownSide
- P1_ownWins
- own_character
- own_health
- own_side
- own_wins
- opp_character
- opp_health
- opp_side
- opp_wins
- stage
- timer
- action

# Algorithm
algo:
Expand Down
90 changes: 50 additions & 40 deletions sheeprl/envs/diambra.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@

import diambra
import diambra.arena
import gym
import gymnasium
import gymnasium as gym
import numpy as np
from diambra.arena import EnvironmentSettings, WrappersSettings
from gymnasium import core
from gymnasium.core import RenderFrame

Expand All @@ -25,12 +25,12 @@ def __init__(
action_space: str = "discrete",
screen_size: Union[int, Tuple[int, int]] = 64,
grayscale: bool = False,
attack_but_combination: bool = True,
sticky_actions: int = 1,
seed: Optional[int] = None,
repeat_action: int = 1,
rank: int = 0,
diambra_settings: Dict[str, Any] = {},
diambra_wrappers: Dict[str, Any] = {},
render_mode: str = "rgb_array",
log_level: int = 0,
) -> None:
super().__init__()

Expand All @@ -39,45 +39,49 @@ def __init__(

if diambra_settings.pop("frame_shape", None) is not None:
warnings.warn("The DIAMBRA frame_shape setting is disabled")
settings = {
if diambra_settings.pop("n_players", None) is not None:
warnings.warn("The DIAMBRA n_players setting is disabled")

role = diambra_settings.pop("role", None)
settings = EnvironmentSettings(
**diambra_settings,
"action_space": action_space.lower(),
"attack_but_combination": attack_but_combination,
}
if sticky_actions > 1:
**{
"game_id": id,
"action_space": eval(action_space),
"n_players": 1,
"role": eval(role) if role is not None else None,
"render_mode": render_mode,
},
)
if repeat_action > 1:
if "step_ratio" not in settings or settings["step_ratio"] > 1:
warnings.warn(
f"step_ratio parameter modified to 1 because the sticky action is active ({sticky_actions})"
f"step_ratio parameter modified to 1 because the sticky action is active ({repeat_action})"
)
settings["step_ratio"] = 1
if diambra_wrappers.pop("hwc_obs_resize", None) is not None:
warnings.warn("The DIAMBRA hwc_obs_resize wrapper is disabled")
if diambra_wrappers.pop("frame_stack", None) is not None:
warnings.warn("The DIAMBRA frame_stack wrapper is disabled")
if diambra_wrappers.pop("frame_shape", None) is not None:
warnings.warn("The DIAMBRA frame_shape wrapper is disabled")
if diambra_wrappers.pop("stack_frames", None) is not None:
warnings.warn("The DIAMBRA stack_frames wrapper is disabled")
if diambra_wrappers.pop("dilation", None) is not None:
warnings.warn("The DIAMBRA dilation wrapper is disabled")
wrappers = {
if diambra_wrappers.pop("flatten", None) is not None:
warnings.warn("The DIAMBRA flatten wrapper is disabled")
wrappers = WrappersSettings(
**diambra_wrappers,
"flatten": True,
"sticky_actions": sticky_actions,
"hwc_obs_resize": screen_size + (1 if grayscale else 3,),
}
self._env = diambra.arena.make(id, settings, wrappers, seed=seed, rank=rank)
**{
"flatten": True,
"repeat_action": repeat_action,
"frame_shape": screen_size + (int(grayscale),),
},
)
self._env = diambra.arena.make(id, settings, wrappers, rank=rank, render_mode=render_mode, log_level=log_level)

# Observation and action space
self.action_space = (
gymnasium.spaces.Discrete(self._env.action_space.n)
if action_space.lower() == "discrete"
else gymnasium.spaces.MultiDiscrete(self._env.action_space.nvec)
)
self.action_space = self._env.action_space
obs = {}
for k in self._env.observation_space.spaces.keys():
if isinstance(self._env.observation_space[k], gym.spaces.Box):
low = self._env.observation_space[k].low
high = self._env.observation_space[k].high
shape = self._env.observation_space[k].shape
dtype = self._env.observation_space[k].dtype
elif isinstance(self._env.observation_space[k], gym.spaces.Discrete):
if isinstance(self._env.observation_space[k], gym.spaces.Discrete):
low = 0
high = self._env.observation_space[k].n - 1
shape = (1,)
Expand All @@ -87,11 +91,15 @@ def __init__(
high = self._env.observation_space[k].nvec - 1
shape = (len(high),)
dtype = np.int32
else:
elif not isinstance(self._env.observation_space[k], gym.spaces.Box):
raise RuntimeError(f"Invalid observation space, got: {type(self._env.observation_space[k])}")
obs[k] = gymnasium.spaces.Box(low, high, shape, dtype)
self.observation_space = gymnasium.spaces.Dict(obs)
self.render_mode = "rgb_array"
obs[k] = (
self._env.observation_space[k]
if isinstance(self._env.observation_space[k], gym.spaces.Box)
else gym.spaces.Box(low, high, shape, dtype)
)
self.observation_space = gym.spaces.Dict(obs)
self.render_mode = render_mode

def __getattr__(self, name):
return getattr(self._env, name)
Expand All @@ -103,17 +111,19 @@ def _convert_obs(self, obs: Dict[str, Union[int, np.ndarray]]) -> Dict[str, np.n
}

def step(self, action: Any) -> Tuple[Any, SupportsFloat, bool, bool, Dict[str, Any]]:
obs, reward, done, infos = self._env.step(action)
obs, reward, done, truncated, infos = self._env.step(action)
infos["env_domain"] = "DIAMBRA"
return self._convert_obs(obs), reward, done or infos.get("env_done", False), False, infos
return self._convert_obs(obs), reward, done or infos.get("env_done", False), truncated, infos

def render(self, mode: str = "rgb_array", **kwargs) -> Optional[Union[RenderFrame, List[RenderFrame]]]:
return self._env.render("rgb_array")
return self._env.render()

def reset(
self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
) -> Tuple[Any, Dict[str, Any]]:
return self._convert_obs(self._env.reset()), {"env_domain": "DIAMBRA"}
obs, infos = self._env.reset(seed=seed, options=options)
infos["env_domain"] = "DIAMBRA"
return self._convert_obs(obs), infos

def close(self) -> None:
self._env.close()
Expand Down