Remap action to fit gym's action space (thu-ml#313)
Co-authored-by: Trinkle23897 <trinkle23897@gmail.com>
ChenDRAG and Trinkle23897 authored Mar 21, 2021
1 parent 9f9c18a commit 626549a
Showing 21 changed files with 145 additions and 77 deletions.
3 changes: 1 addition & 2 deletions examples/box2d/bipedal_hardcore_sac.py
@@ -117,9 +117,8 @@ def test_sac_bipedal(args=get_args()):

     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma, alpha=args.alpha,
-        estimation_step=args.n_step)
+        estimation_step=args.n_step, action_space=env.action_space)
     # load a previous policy
     if args.resume_path:
         policy.load_state_dict(torch.load(args.resume_path))
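The same migration repeats across the example and test scripts below: the scalar action_range pair disappears and the full gym space is handed to the policy, which now bounds and rescales actions itself. A minimal before/after sketch from the caller's side (assuming gym is installed; Pendulum-v0 is just a stand-in continuous-action env):

import gym

env = gym.make("Pendulum-v0")
# old style: only scalar bounds were passed to the policy
old_action_range = [env.action_space.low[0], env.action_space.high[0]]
# new style: the whole space is passed; the policy remaps actions internally
new_action_space = env.action_space
print(old_action_range, new_action_space.low, new_action_space.high)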
4 changes: 2 additions & 2 deletions examples/box2d/mcc_sac.py
@@ -90,10 +90,10 @@ def test_sac(args=get_args()):

     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma, alpha=args.alpha,
         reward_normalization=args.rew_norm,
-        exploration_noise=OUNoise(0.0, args.noise_std))
+        exploration_noise=OUNoise(0.0, args.noise_std),
+        action_space=env.action_space)
     # collector
     train_collector = Collector(
         policy, train_envs,
3 changes: 1 addition & 2 deletions examples/mujoco/mujoco_ddpg.py
@@ -84,10 +84,9 @@ def test_ddpg(args=get_args()):
     critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
     policy = DDPGPolicy(
         actor, actor_optim, critic, critic_optim,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma,
         exploration_noise=GaussianNoise(sigma=args.exploration_noise),
-        estimation_step=args.n_step)
+        estimation_step=args.n_step, action_space=env.action_space)
     # load a previous policy
     if args.resume_path:
         policy.load_state_dict(torch.load(
3 changes: 1 addition & 2 deletions examples/mujoco/mujoco_sac.py
@@ -97,9 +97,8 @@ def test_sac(args=get_args()):

     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma, alpha=args.alpha,
-        estimation_step=args.n_step)
+        estimation_step=args.n_step, action_space=env.action_space)
     # load a previous policy
     if args.resume_path:
         policy.load_state_dict(torch.load(
4 changes: 2 additions & 2 deletions examples/mujoco/mujoco_td3.py
@@ -95,11 +95,11 @@ def test_td3(args=get_args()):

     policy = TD3Policy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma,
         exploration_noise=GaussianNoise(sigma=args.exploration_noise),
         policy_noise=args.policy_noise, update_actor_freq=args.update_actor_freq,
-        noise_clip=args.noise_clip, estimation_step=args.n_step)
+        noise_clip=args.noise_clip, estimation_step=args.n_step,
+        action_space=env.action_space)

     # load a previous policy
     if args.resume_path:
3 changes: 1 addition & 2 deletions test/continuous/test_ddpg.py
@@ -77,11 +77,10 @@ def test_ddpg(args=get_args()):
     critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
     policy = DDPGPolicy(
         actor, actor_optim, critic, critic_optim,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma,
         exploration_noise=GaussianNoise(sigma=args.exploration_noise),
         reward_normalization=args.rew_norm,
-        estimation_step=args.n_step)
+        estimation_step=args.n_step, action_space=env.action_space)
     # collector
     train_collector = Collector(
         policy, train_envs,
5 changes: 2 additions & 3 deletions test/continuous/test_ppo.py
@@ -100,9 +100,8 @@ def dist(*logits):
         # dual_clip=args.dual_clip,
         # dual clip cause monotonically increasing log_std :)
         value_clip=args.value_clip,
-        # action_range=[env.action_space.low[0], env.action_space.high[0]],)
-        # if clip the action, ppo would not converge :)
-        gae_lambda=args.gae_lambda)
+        gae_lambda=args.gae_lambda,
+        action_space=env.action_space)
     # collector
     train_collector = Collector(
         policy, train_envs,
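The deleted "if clip the action, ppo would not converge" comment is exactly what the new design addresses: clipping now happens only on the copy of the action sent to the environment, while the buffer (and hence the PPO update) keeps the raw sample. A rough numpy illustration of that split, with hypothetical values rather than Tianshou code:

import numpy as np

raw = np.array([1.7])                  # Gaussian policy sample, outside [-1, 1]
sent_to_env = np.clip(raw, -1.0, 1.0)  # what map_action forwards to env.step
stored_in_buffer = raw                 # what the update computes log-probs on
print(sent_to_env, stored_in_buffer)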
3 changes: 1 addition & 2 deletions test/continuous/test_sac_with_il.py
@@ -87,10 +87,9 @@ def test_sac_with_il(args=get_args()):
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma, alpha=args.alpha,
         reward_normalization=args.rew_norm,
-        estimation_step=args.n_step)
+        estimation_step=args.n_step, action_space=env.action_space)
     # collector
     train_collector = Collector(
         policy, train_envs,
4 changes: 2 additions & 2 deletions test/continuous/test_td3.py
@@ -86,14 +86,14 @@ def test_td3(args=get_args()):
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = TD3Policy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma,
         exploration_noise=GaussianNoise(sigma=args.exploration_noise),
         policy_noise=args.policy_noise,
         update_actor_freq=args.update_actor_freq,
         noise_clip=args.noise_clip,
         reward_normalization=args.rew_norm,
-        estimation_step=args.n_step)
+        estimation_step=args.n_step,
+        action_space=env.action_space)
     # collector
     train_collector = Collector(
         policy, train_envs,
3 changes: 2 additions & 1 deletion test/discrete/test_a2c_with_il.py
@@ -80,7 +80,8 @@ def test_a2c_with_il(args=get_args()):
     policy = A2CPolicy(
         actor, critic, optim, dist, args.gamma, gae_lambda=args.gae_lambda,
         vf_coef=args.vf_coef, ent_coef=args.ent_coef,
-        max_grad_norm=args.max_grad_norm, reward_normalization=args.rew_norm)
+        max_grad_norm=args.max_grad_norm, reward_normalization=args.rew_norm,
+        action_space=env.action_space)
     # collector
     train_collector = Collector(
         policy, train_envs,
3 changes: 2 additions & 1 deletion test/discrete/test_pg.py
@@ -63,7 +63,8 @@ def test_pg(args=get_args()):
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
     dist = torch.distributions.Categorical
     policy = PGPolicy(net, optim, dist, args.gamma,
-                      reward_normalization=args.rew_norm)
+                      reward_normalization=args.rew_norm,
+                      action_space=env.action_space)
     # collector
     train_collector = Collector(
         policy, train_envs,
4 changes: 2 additions & 2 deletions test/discrete/test_ppo.py
@@ -85,11 +85,11 @@ def test_ppo(args=get_args()):
         eps_clip=args.eps_clip,
         vf_coef=args.vf_coef,
         ent_coef=args.ent_coef,
-        action_range=None,
         gae_lambda=args.gae_lambda,
         reward_normalization=args.rew_norm,
         dual_clip=args.dual_clip,
-        value_clip=args.value_clip)
+        value_clip=args.value_clip,
+        action_space=env.action_space)
     # collector
     train_collector = Collector(
         policy, train_envs,
8 changes: 6 additions & 2 deletions tianshou/data/collector.py
@@ -219,8 +219,10 @@ def collect(
                 act = self.policy.exploration_noise(act, self.data)
             self.data.update(policy=policy, act=act)

+            # get bounded and remapped actions first (not saved into buffer)
+            action_remap = self.policy.map_action(self.data.act)
             # step in env
-            obs_next, rew, done, info = self.env.step(self.data.act, id=ready_env_ids)
+            obs_next, rew, done, info = self.env.step(action_remap, id=ready_env_ids)

             self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)
             if self.preprocess_fn:
@@ -419,8 +421,10 @@ def collect(
                     _alloc_by_keys_diff(whole_data, self.data, self.env_num, False)
                 whole_data[ready_env_ids] = self.data  # lots of overhead

+            # get bounded and remapped actions first (not saved into buffer)
+            action_remap = self.policy.map_action(self.data.act)
             # step in env
-            obs_next, rew, done, info = self.env.step(self.data.act, id=ready_env_ids)
+            obs_next, rew, done, info = self.env.step(action_remap, id=ready_env_ids)

             # change self.data here because ready_env_ids has changed
             ready_env_ids = np.array([i["env_id"] for i in info])
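The ordering matters: the raw action is written to the buffer, and only the remapped copy reaches the environment. A simplified, hypothetical helper (not the Collector's real internals, which batch many envs through Batch objects) showing that flow:

def step_once(policy, env, buffer_add, obs):
    # sketch only: illustrates the order of operations introduced above
    raw_act = policy(obs).act                 # unbounded network output
    env_act = policy.map_action(raw_act)      # bounded + rescaled, env-only copy
    obs_next, rew, done, info = env.step(env_act)
    buffer_add(obs=obs, act=raw_act, rew=rew, done=done, obs_next=obs_next)
    return obs_next, done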
42 changes: 40 additions & 2 deletions tianshou/policy/base.py
@@ -53,14 +53,20 @@ class BasePolicy(ABC, nn.Module):

     def __init__(
         self,
-        observation_space: gym.Space = None,
-        action_space: gym.Space = None
+        observation_space: Optional[gym.Space] = None,
+        action_space: Optional[gym.Space] = None,
+        action_scaling: bool = False,
+        action_bound_method: str = "",
     ) -> None:
         super().__init__()
         self.observation_space = observation_space
         self.action_space = action_space
         self.agent_id = 0
         self.updating = False
+        self.action_scaling = action_scaling
+        # can be one of ("clip", "tanh", ""), empty string means no bounding
+        assert action_bound_method in ("", "clip", "tanh")
+        self.action_bound_method = action_bound_method
         self._compile()

     def set_agent_id(self, agent_id: int) -> None:
@@ -114,6 +120,38 @@ def forward(
         """
         pass

+    def map_action(self, act: Union[Batch, np.ndarray]) -> Union[Batch, np.ndarray]:
+        """Map raw network output to action range in gym's env.action_space.
+
+        This function is called in :meth:`~tianshou.data.Collector.collect` and only
+        affects action sending to env. Remapped action will not be stored in buffer
+        and thus can be viewed as a part of env (a black box action transformation).
+
+        Action mapping includes 2 standard procedures: bounding and scaling. Bounding
+        procedure expects original action range is (-inf, inf) and maps it to [-1, 1],
+        while scaling procedure expects original action range is (-1, 1) and maps it
+        to [action_space.low, action_space.high]. Bounding procedure is applied first.
+
+        :param act: a data batch or numpy.ndarray which is the action taken by
+            policy.forward.
+
+        :return: action in the same form of input "act" but remap to the target action
+            space.
+        """
+        if isinstance(self.action_space, gym.spaces.Box) and \
+                isinstance(act, np.ndarray):
+            # currently this action mapping only supports np.ndarray action
+            if self.action_bound_method == "clip":
+                act = np.clip(act, -1.0, 1.0)
+            elif self.action_bound_method == "tanh":
+                act = np.tanh(act)
+            if self.action_scaling:
+                assert np.all(act >= -1.0) and np.all(act <= 1.0), \
+                    "action scaling only accepts raw action range = [-1, 1]"
+                low, high = self.action_space.low, self.action_space.high
+                act = low + (high - low) * (act + 1.0) / 2.0
+        return act
+
     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
     ) -> Batch:
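A standalone numeric check of the mapping above, assuming action_bound_method="tanh" and action_scaling=True on a hypothetical Box with low=0 and high=(2, 4, 6); pure numpy, no Tianshou import needed:

import numpy as np

act = np.array([-3.0, 0.0, 3.0])                     # raw, unbounded network output
act = np.tanh(act)                                   # bounding: ~[-0.995, 0.0, 0.995]
low, high = np.zeros(3), np.array([2.0, 4.0, 6.0])   # hypothetical Box bounds
mapped = low + (high - low) * (act + 1.0) / 2.0      # scaling into [low, high]
print(mapped)                                        # ~[0.005, 2.0, 5.985]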
7 changes: 7 additions & 0 deletions tianshou/policy/modelfree/a2c.py
@@ -31,6 +31,13 @@ class A2CPolicy(PGPolicy):
         depends on the size of available memory and the memory cost of the
         model; should be as large as possible within the memory constraint.
         Default to 256.
+    :param bool action_scaling: whether to map actions from range [-1, 1] to range
+        [action_spaces.low, action_spaces.high]. Default to True.
+    :param str action_bound_method: method to bound action to range [-1, 1], can be
+        either "clip" (for simply clipping the action), "tanh" (for applying tanh
+        squashing) for now, or empty string for no bounding. Default to "clip".
+    :param Optional[gym.Space] action_space: env's action space, mandatory if you want
+        to use option "action_scaling" or "action_bound_method". Default to None.

     .. seealso::
21 changes: 11 additions & 10 deletions tianshou/policy/modelfree/ddpg.py
@@ -16,15 +16,20 @@ class DDPGPolicy(BasePolicy):
     :param torch.optim.Optimizer actor_optim: the optimizer for actor network.
     :param torch.nn.Module critic: the critic network. (s, a -> Q(s, a))
     :param torch.optim.Optimizer critic_optim: the optimizer for critic network.
-    :param action_range: the action range (minimum, maximum).
-    :type action_range: Tuple[float, float]
     :param float tau: param for soft update of the target network. Default to 0.005.
     :param float gamma: discount factor, in [0, 1]. Default to 0.99.
     :param BaseNoise exploration_noise: the exploration noise,
         add to the action. Default to ``GaussianNoise(sigma=0.1)``.
     :param bool reward_normalization: normalize the reward to Normal(0, 1),
         Default to False.
     :param int estimation_step: the number of steps to look ahead. Default to 1.
+    :param bool action_scaling: whether to map actions from range [-1, 1] to range
+        [action_spaces.low, action_spaces.high]. Default to True.
+    :param str action_bound_method: method to bound action to range [-1, 1], can be
+        either "clip" (for simply clipping the action), "tanh" (for applying tanh
+        squashing) for now, or empty string for no bounding. Default to "clip".
+    :param Optional[gym.Space] action_space: env's action space, mandatory if you want
+        to use option "action_scaling" or "action_bound_method". Default to None.

     .. seealso::
@@ -38,15 +43,17 @@ def __init__(
         actor_optim: Optional[torch.optim.Optimizer],
         critic: Optional[torch.nn.Module],
         critic_optim: Optional[torch.optim.Optimizer],
-        action_range: Tuple[float, float],
         tau: float = 0.005,
         gamma: float = 0.99,
         exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1),
         reward_normalization: bool = False,
         estimation_step: int = 1,
+        action_scaling: bool = True,
+        action_bound_method: str = "clip",
         **kwargs: Any,
     ) -> None:
-        super().__init__(**kwargs)
+        super().__init__(action_scaling=action_scaling,
+                         action_bound_method=action_bound_method, **kwargs)
         if actor is not None and actor_optim is not None:
             self.actor: torch.nn.Module = actor
             self.actor_old = deepcopy(actor)
@@ -62,9 +69,6 @@ def __init__(
         assert 0.0 <= gamma <= 1.0, "gamma should be in [0, 1]"
         self._gamma = gamma
         self._noise = exploration_noise
-        self._range = action_range
-        self._action_bias = (action_range[0] + action_range[1]) / 2.0
-        self._action_scale = (action_range[1] - action_range[0]) / 2.0
         # it is only a little difference to use GaussianNoise
         # self.noise = OUNoise()
         self._rew_norm = reward_normalization
@@ -128,8 +132,6 @@ def forward(
         model = getattr(self, model)
         obs = batch[input]
         actions, h = model(obs, state=state, info=batch.info)
-        actions += self._action_bias
-        actions = actions.clamp(self._range[0], self._range[1])
         return Batch(act=actions, state=h)

     @staticmethod
@@ -168,5 +170,4 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
     def exploration_noise(self, act: np.ndarray, batch: Batch) -> np.ndarray:
         if self._noise:
             act = act + self._noise(act.shape)
-            act = act.clip(self._range[0], self._range[1])
         return act
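Note that exploration_noise no longer clips the noisy action either: the unclipped value is what the buffer stores, and only map_action (with DDPG's default action_bound_method="clip") bounds the copy handed to the environment. A small numpy illustration with made-up numbers:

import numpy as np

np.random.seed(0)
raw_act = np.array([0.97])
noisy = raw_act + np.random.normal(0.0, 0.1, size=raw_act.shape)  # may leave [-1, 1]
env_act = np.clip(noisy, -1.0, 1.0)   # applied later, inside map_action
print(noisy, env_act)                 # the unclipped value is what gets stored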
8 changes: 4 additions & 4 deletions tianshou/policy/modelfree/discrete_sac.py
@@ -49,10 +49,10 @@ def __init__(
         estimation_step: int = 1,
         **kwargs: Any,
     ) -> None:
-        super().__init__(actor, actor_optim, critic1, critic1_optim, critic2,
-                         critic2_optim, (-np.inf, np.inf), tau, gamma, alpha,
-                         reward_normalization, estimation_step,
-                         **kwargs)
+        super().__init__(
+            actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
+            tau, gamma, alpha, reward_normalization, estimation_step,
+            action_scaling=False, action_bound_method="", **kwargs)
         self._alpha: Union[float, torch.Tensor]

     def forward(  # type: ignore
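Discrete SAC opts out of the new machinery because remapping only applies to Box action spaces; for a Discrete space, map_action returns the action untouched. A quick check of the space types this relies on (assuming standard gym spaces):

import gym.spaces

print(isinstance(gym.spaces.Discrete(4), gym.spaces.Box))                 # False -> no remapping
print(isinstance(gym.spaces.Box(-1.0, 1.0, shape=(2,)), gym.spaces.Box))  # True  -> remapping applies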
12 changes: 11 additions & 1 deletion tianshou/policy/modelfree/pg.py
@@ -15,6 +15,13 @@ class PGPolicy(BasePolicy):
     :param dist_fn: distribution class for computing the action.
     :type dist_fn: Type[torch.distributions.Distribution]
     :param float discount_factor: in [0, 1]. Default to 0.99.
+    :param bool action_scaling: whether to map actions from range [-1, 1] to range
+        [action_spaces.low, action_spaces.high]. Default to True.
+    :param str action_bound_method: method to bound action to range [-1, 1], can be
+        either "clip" (for simply clipping the action), "tanh" (for applying tanh
+        squashing) for now, or empty string for no bounding. Default to "clip".
+    :param Optional[gym.Space] action_space: env's action space, mandatory if you want
+        to use option "action_scaling" or "action_bound_method". Default to None.

     .. seealso::
@@ -29,9 +36,12 @@ def __init__(
         dist_fn: Type[torch.distributions.Distribution],
         discount_factor: float = 0.99,
         reward_normalization: bool = False,
+        action_scaling: bool = True,
+        action_bound_method: str = "clip",
         **kwargs: Any,
     ) -> None:
-        super().__init__(**kwargs)
+        super().__init__(action_scaling=action_scaling,
+                         action_bound_method=action_bound_method, **kwargs)
         if model is not None:
             self.model: torch.nn.Module = model
         self.optim = optim