diff --git a/alf/algorithms/data_transformer.py b/alf/algorithms/data_transformer.py index 2fa6c8137..24d4ae772 100644 --- a/alf/algorithms/data_transformer.py +++ b/alf/algorithms/data_transformer.py @@ -736,14 +736,25 @@ class HindsightExperienceTransformer(DataTransformer): of the current timestep. The exact field names can be provided via arguments to the class ``__init__``. + NOTE: The HindsightExperienceTransformer has to happen before any transformer which changes + reward or achieved_goal fields, e.g. observation normalizer, reward clipper, etc.. + See `documentation <../../docs/notes/knowledge_base.rst#datatransformers>`_ for details. + To use this class, add it to any existing data transformers, e.g. use this config if ``ObservationNormalizer`` is an existing data transformer: .. code-block:: python - ReplayBuffer.keep_episodic_info=True - HindsightExperienceTransformer.her_proportion=0.8 - TrainerConfig.data_transformer_ctor=[@HindsightExperienceTransformer, @ObservationNormalizer] + alf.config('ReplayBuffer', keep_episodic_info=True) + alf.config( + 'HindsightExperienceTransformer', + her_proportion=0.8 + ) + alf.config( + 'TrainerConfig', + data_transformer_ctor=[ + HindsightExperienceTransformer, ObservationNormalizer + ]) See unit test for more details on behavior. """ @@ -820,9 +831,10 @@ def transform_experience(self, experience: Experience): # relabel only these sampled indices her_cond = torch.rand(batch_size) < her_proportion (her_indices, ) = torch.where(her_cond) + has_her = torch.any(her_cond) - last_step_pos = start_pos[her_indices] + batch_length - 1 - last_env_ids = env_ids[her_indices] + last_step_pos = start_pos + batch_length - 1 + last_env_ids = env_ids # Get x, y indices of LAST steps dist = buffer.steps_to_episode_end(last_step_pos, last_env_ids) if alf.summary.should_record_summaries(): @@ -831,22 +843,24 @@ def transform_experience(self, experience: Experience): torch.mean(dist.type(torch.float32))) # get random future state - future_idx = last_step_pos + (torch.rand(*dist.shape) * - (dist + 1)).to(torch.int64) + future_dist = (torch.rand(*dist.shape) * (dist + 1)).to( + torch.int64) + future_idx = last_step_pos + future_dist future_ag = buffer.get_field(self._achieved_goal_field, last_env_ids, future_idx).unsqueeze(1) # relabel desired goal result_desired_goal = alf.nest.get_field(result, self._desired_goal_field) - relabed_goal = result_desired_goal.clone() + relabeled_goal = result_desired_goal.clone() her_batch_index_tuple = (her_indices.unsqueeze(1), torch.arange(batch_length).unsqueeze(0)) - relabed_goal[her_batch_index_tuple] = future_ag + if has_her: + relabeled_goal[her_batch_index_tuple] = future_ag[her_indices] # recompute rewards result_ag = alf.nest.get_field(result, self._achieved_goal_field) - relabeled_rewards = self._reward_fn(result_ag, relabed_goal) + relabeled_rewards = self._reward_fn(result_ag, relabeled_goal) non_her_or_fst = ~her_cond.unsqueeze(1) & (result.step_type != StepType.FIRST) @@ -876,21 +890,28 @@ def transform_experience(self, experience: Experience): alf.summary.scalar( "replayer/" + buffer._name + ".reward_mean_before_relabel", torch.mean(result.reward[her_indices][:-1])) - alf.summary.scalar( - "replayer/" + buffer._name + ".reward_mean_after_relabel", - torch.mean(relabeled_rewards[her_indices][:-1])) + if has_her: + alf.summary.scalar( + "replayer/" + buffer._name + ".reward_mean_after_relabel", + torch.mean(relabeled_rewards[her_indices][:-1])) + alf.summary.scalar("replayer/" + buffer._name + ".future_distance", + 
torch.mean(future_dist.float())) result = alf.nest.transform_nest( - result, self._desired_goal_field, lambda _: relabed_goal) - + result, self._desired_goal_field, lambda _: relabeled_goal) result = result.update_time_step_field('reward', relabeled_rewards) - + info = info._replace(her=her_cond, future_distance=future_dist) if alf.get_default_device() != buffer.device: for f in accessed_fields: result = alf.nest.transform_nest( result, f, lambda t: convert_device(t)) - result = alf.nest.transform_nest( - result, "batch_info.replay_buffer", lambda _: buffer) + info = convert_device(info) + info = info._replace( + her=info.her.unsqueeze(1).expand(result.reward.shape[:2]), + future_distance=info.future_distance.unsqueeze(1).expand( + result.reward.shape[:2]), + replay_buffer=buffer) + result = alf.data_structures.add_batch_info(result, info) return result diff --git a/alf/algorithms/ddpg_algorithm.py b/alf/algorithms/ddpg_algorithm.py index 7c0678998..d1c5d6fd7 100644 --- a/alf/algorithms/ddpg_algorithm.py +++ b/alf/algorithms/ddpg_algorithm.py @@ -40,9 +40,20 @@ DdpgActorState = namedtuple("DdpgActorState", ['actor', 'critics']) DdpgState = namedtuple("DdpgState", ['actor', 'critics']) DdpgInfo = namedtuple( - "DdpgInfo", [ - "reward", "step_type", "discount", "action", "action_distribution", - "actor_loss", "critic", "discounted_return" + "DdpgInfo", + [ + "reward", + "step_type", + "discount", + "action", + "action_distribution", + "actor_loss", + "critic", + # Optional fields for value target lower bounding or Hindsight relabeling. + # TODO: Extract these into a HerAlgorithm wrapper for easier adoption of HER. + "discounted_return", + "future_distance", + "her" ], default_value=()) DdpgLossInfo = namedtuple('DdpgLossInfo', ('actor', 'critic')) diff --git a/alf/algorithms/one_step_loss.py b/alf/algorithms/one_step_loss.py index 34ee329ad..e687f9dc5 100644 --- a/alf/algorithms/one_step_loss.py +++ b/alf/algorithms/one_step_loss.py @@ -16,12 +16,12 @@ from typing import Union, List, Callable import alf -from alf.algorithms.td_loss import TDLoss, TDQRLoss +from alf.algorithms.td_loss import LowerBoundedTDLoss, TDQRLoss from alf.utils import losses @alf.configurable -class OneStepTDLoss(TDLoss): +class OneStepTDLoss(LowerBoundedTDLoss): def __init__(self, gamma: Union[float, List[float]] = 0.99, td_error_loss_fn: Callable = losses.element_wise_squared_loss, diff --git a/alf/algorithms/sac_algorithm.py b/alf/algorithms/sac_algorithm.py index 8235d0834..2a4c27e3b 100644 --- a/alf/algorithms/sac_algorithm.py +++ b/alf/algorithms/sac_algorithm.py @@ -54,9 +54,22 @@ "SacActorInfo", ["actor_loss", "neg_entropy"], default_value=()) SacInfo = namedtuple( - "SacInfo", [ - "reward", "step_type", "discount", "action", "action_distribution", - "actor", "critic", "alpha", "log_pi", "discounted_return" + "SacInfo", + [ + "reward", + "step_type", + "discount", + "action", + "action_distribution", + "actor", + "critic", + "alpha", + "log_pi", + # Optional fields for value target lower bounding or Hindsight relabeling. + # TODO: Extract these into a HerAlgorithm wrapper for easier adoption of HER. 
+ "discounted_return", + "future_distance", + "her" ], default_value=()) diff --git a/alf/algorithms/td_loss.py b/alf/algorithms/td_loss.py index 80c2c0a93..6f4e27e65 100644 --- a/alf/algorithms/td_loss.py +++ b/alf/algorithms/td_loss.py @@ -106,7 +106,11 @@ def gamma(self): """ return self._gamma.clone() - def compute_td_target(self, info: namedtuple, target_value: torch.Tensor): + def compute_td_target(self, + info: namedtuple, + value: torch.Tensor, + target_value: torch.Tensor, + qr: bool = False): """Calculate the td target. The first dimension of all the tensors is time dimension and the second @@ -119,46 +123,48 @@ def compute_td_target(self, info: namedtuple, target_value: torch.Tensor): - reward: - step_type: - discount: + value (torch.Tensor): the time-major tensor for the value at + each time step. Some of its value can be overwritten and passed + back to the caller. target_value (torch.Tensor): the time-major tensor for the value at each time step. This is used to calculate return. ``target_value`` can be same as ``value``. Returns: - td_target + td_target, updated value, optional constraint_loss """ + if not qr and info.reward.ndim == 3: + # Multi-dim reward, not quantile regression. + # [T, B, D] or [T, B, 1] + discounts = info.discount.unsqueeze(-1) * self._gamma + else: + # [T, B] + discounts = info.discount * self._gamma + if self._lambda == 1.0: returns = value_ops.discounted_return( rewards=info.reward, values=target_value, step_types=info.step_type, - discounts=info.discount * self._gamma) + discounts=discounts) elif self._lambda == 0.0: returns = value_ops.one_step_discounted_return( rewards=info.reward, values=target_value, step_types=info.step_type, - discounts=info.discount * self._gamma) + discounts=discounts) else: advantages = value_ops.generalized_advantage_estimation( rewards=info.reward, values=target_value, step_types=info.step_type, - discounts=info.discount * self._gamma, + discounts=discounts, td_lambda=self._lambda) returns = advantages + target_value[:-1] - disc_ret = () - if hasattr(info, "discounted_return"): - disc_ret = info.discounted_return - if disc_ret != (): - with alf.summary.scope(self._name): - episode_ended = disc_ret > self._default_return - alf.summary.scalar("episodic_discounted_return_all", - torch.mean(disc_ret[episode_ended])) - alf.summary.scalar( - "value_episode_ended_all", - torch.mean(value[:-1][:, episode_ended[0, :]])) + returns = advantages + value[:-1] + returns = returns.detach() - return returns + return returns, value, None def forward(self, info: namedtuple, value: torch.Tensor, target_value: torch.Tensor): @@ -182,7 +188,8 @@ def forward(self, info: namedtuple, value: torch.Tensor, Returns: LossInfo: with the ``extra`` field same as ``loss``. 
""" - returns = self.compute_td_target(info, target_value) + returns, value, constraint_loss = self.compute_td_target( + info, value, target_value) value = value[:-1] if self._normalize_target: @@ -230,6 +237,221 @@ def _summarize(v, r, td, suffix): return LossInfo(loss=loss, extra=loss) +@alf.configurable +class LowerBoundedTDLoss(TDLoss): + """Temporal difference loss with value target lower bounding.""" + + def __init__(self, + gamma: Union[float, List[float]] = 0.99, + td_error_loss_fn: Callable = element_wise_squared_loss, + td_lambda: float = 0.95, + normalize_target: bool = False, + lb_target_q: float = 0., + default_return: float = -1000., + improve_w_goal_return: bool = False, + improve_w_nstep_bootstrap: bool = False, + improve_w_nstep_only: bool = False, + reward_multiplier: float = 1., + positive_reward: bool = True, + debug_summaries: bool = False, + name: str = "LbTDLoss"): + r""" + Args: + gamma .. use_retrace: pass through to TDLoss. + lb_target_q: between 0 and 1. When not zero, use this mixing rate for the + lower bounded value target. Only supports batch_length == 2, one step td. + default_return: Keep it the same as replay_buffer.default_return to plot to + tensorboard episodic_discounted_return only for the timesteps whose + episode already ended. + improve_w_goal_return: Use return calculated from the distance to hindsight + goals. Only supports batch_length == 2, one step td. + improve_w_nstep_bootstrap: Look ahead 2 to n steps, and take the largest + bootstrapped return to lower bound the value target of the 1st step. + improve_w_nstep_only: Only use the n-th step bootstrapped return as + value target lower bound. + reward_multiplier: Weight on the hindsight goal return. + positive_reward: If True, assumes 0/1 goal reward, otherwise, -1/0. + debug_summaries: True if debug summaries should be created. + name: The name of this loss. + """ + super().__init__( + gamma=gamma, + td_error_loss_fn=td_error_loss_fn, + td_lambda=td_lambda, + normalize_target=normalize_target, + name=name, + debug_summaries=debug_summaries) + + self._lb_target_q = lb_target_q + self._default_return = default_return + self._improve_w_goal_return = improve_w_goal_return + self._improve_w_nstep_bootstrap = improve_w_nstep_bootstrap + self._improve_w_nstep_only = improve_w_nstep_only + self._reward_multiplier = reward_multiplier + self._positive_reward = positive_reward + + def compute_td_target(self, + info: namedtuple, + value: torch.Tensor, + target_value: torch.Tensor, + qr: bool = False): + """Calculate the td target. + + The first dimension of all the tensors is time dimension and the second + dimesion is the batch dimension. + + Args: + info (namedtuple): experience collected from ``unroll()`` or + a replay buffer. All tensors are time-major. ``info`` should + contain the following fields: + - reward: + - step_type: + - discount: + value (torch.Tensor): the time-major tensor for the value at + each time step. Some of its value can be overwritten and passed + back to the caller. + target_value (torch.Tensor): the time-major tensor for the value at + each time step. This is used to calculate return. ``target_value`` + can be same as ``value``, except for Retrace. 
+ Returns: + td_target, updated value, optional constraint_loss + """ + returns, value, _ = super().compute_td_target(info, value, + target_value, qr) + + constraint_loss = None + if self._improve_w_nstep_bootstrap: + assert self._lambda == 1.0, "td lambda does not work with this" + future_returns = value_ops.first_step_future_discounted_returns( + rewards=info.reward, + values=target_value, + step_types=info.step_type, + discounts=discounts) + returns = value_ops.one_step_discounted_return( + rewards=info.reward, + values=target_value, + step_types=info.step_type, + discounts=discounts) + assert torch.all((returns[0] == future_returns[0]) | ( + info.step_type[0] == alf.data_structures.StepType.LAST)), \ + str(returns[0]) + " ne\n" + str(future_returns[0]) + \ + '\nrwd: ' + str(info.reward[0:2]) + \ + '\nlast: ' + str(info.step_type[0:2]) + \ + '\ndisct: ' + str(discounts[0:2]) + \ + '\nv: ' + str(target_value[0:2]) + if self._improve_w_nstep_only: + future_returns = future_returns[ + -1] # last is the n-step return + else: + future_returns = torch.max(future_returns, dim=0)[0] + + with alf.summary.scope(self._name): + alf.summary.scalar( + "max_1_to_n_future_return_gt_td", + torch.mean((returns[0] < future_returns).float())) + alf.summary.scalar("first_step_discounted_return", + torch.mean(returns[0])) + + returns[0] = torch.max(future_returns, returns[0]).detach() + returns[1:] = 0 + value = value.clone() + value[1:] = 0 + + disc_ret = () + if hasattr(info, "discounted_return"): + disc_ret = info.discounted_return + if disc_ret != (): + with alf.summary.scope(self._name): + episode_ended = disc_ret > self._default_return + alf.summary.scalar("episodic_discounted_return_all", + torch.mean(disc_ret[episode_ended])) + alf.summary.scalar( + "value_episode_ended_all", + torch.mean(value[:-1][:, episode_ended[0, :]])) + + if self._lb_target_q > 0 and disc_ret != (): + her_cond = info.her + mask = torch.ones(returns.shape, dtype=torch.bool) + if her_cond != () and torch.any(~her_cond): + mask = ~her_cond[:-1] + disc_ret = disc_ret[ + 1:] # it's expanded in ddpg_algorithm, need to revert back. + assert returns.shape == disc_ret.shape, "%s %s" % (returns.shape, + disc_ret.shape) + with alf.summary.scope(self._name): + alf.summary.scalar( + "episodic_return_gt_td", + torch.mean((returns < disc_ret).float()[mask])) + alf.summary.scalar( + "episodic_discounted_return", + torch.mean( + disc_ret[mask & (disc_ret > self._default_return)])) + returns[mask] = (1 - self._lb_target_q) * returns[mask] + \ + self._lb_target_q * torch.max(returns, disc_ret)[mask] + + if self._improve_w_goal_return: + batch_length, batch_size = returns.shape[:2] + her_cond = info.her + if her_cond != () and torch.any(her_cond): + dist = info.future_distance + if self._positive_reward: + goal_return = torch.pow( + self._gamma * torch.ones(her_cond.shape), dist) + else: + goal_return = -(1. - torch.pow(self._gamma, dist)) / ( + 1. 
- self._gamma) + goal_return *= self._reward_multiplier + goal_return = goal_return[:-1] + returns_0 = returns + # Multi-dim reward: + if len(returns.shape) > 2: + returns_0 = returns[:, :, 0] + returns_0 = torch.where(her_cond[:-1], + torch.max(returns_0, goal_return), + returns_0) + with alf.summary.scope(self._name): + alf.summary.scalar( + "goal_return_gt_td", + torch.mean((returns_0 < goal_return).float())) + alf.summary.scalar("goal_return", torch.mean(goal_return)) + if len(returns.shape) > 2: + returns[:, :, 0] = returns_0 + else: + returns = returns_0 + + return returns, value, constraint_loss + + def forward(self, info: namedtuple, value: torch.Tensor, + target_value: torch.Tensor): + """Calculate the loss. + + The first dimension of all the tensors is time dimension and the second + dimesion is the batch dimension. + + Args: + info: experience collected from ``unroll()`` or + a replay buffer. All tensors are time-major. ``info`` should + contain the following fields: + - reward: + - step_type: + - discount: + value: the time-major tensor for the value at each time + step. The loss is between this and the calculated return. + target_value: the time-major tensor for the value at + each time step. This is used to calculate return. ``target_value`` + can be same as ``value``. + Returns: + LossInfo: with the ``extra`` field same as ``loss``. + """ + loss_info = super().forward(info, value, target_value) + loss = loss_info.loss + if self._improve_w_nstep_bootstrap: + # Ignore 2nd to n-th step losses. + loss[1:] = 0 + + return LossInfo(loss=loss, extra=loss) + + @alf.configurable class TDQRLoss(TDLoss): """Temporal difference quantile regression loss. @@ -301,7 +523,8 @@ def forward(self, info: namedtuple, value: torch.Tensor, assert target_value.shape[-1] == self._num_quantiles, ( "The input target_value should have same num_quantiles as pre-defiend." ) - returns = self.compute_td_target(info, target_value) + returns, value, constraint_loss = self.compute_td_target( + info, value, target_value, qr=True) value = value[:-1] # for quantile regression TD, the value and target both have shape diff --git a/alf/algorithms/td_loss_test.py b/alf/algorithms/td_loss_test.py new file mode 100644 index 000000000..2458fb89e --- /dev/null +++ b/alf/algorithms/td_loss_test.py @@ -0,0 +1,65 @@ +# Copyright (c) 2019 Horizon Robotics. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
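A minimal sketch (illustration only, not part of the patch) of the two lower-bounding rules implemented above. The names below are stand-ins for quantities computed inside ``LowerBoundedTDLoss.compute_td_target``: ``lower_bound_td_target`` mirrors the ``lb_target_q`` mixing of the bootstrapped return with the recorded episodic discounted return, and ``hindsight_goal_return`` is the closed-form return used when ``improve_w_goal_return`` is set.

import torch

def lower_bound_td_target(td_target: torch.Tensor,
                          episodic_return: torch.Tensor,
                          mask: torch.Tensor,
                          lb_target_q: float) -> torch.Tensor:
    # Wherever the recorded episodic discounted return exceeds the bootstrapped
    # TD target, mix in the larger value at rate ``lb_target_q``. ``mask`` selects
    # the (non-hindsight) steps to which the bound is applied.
    bounded = torch.max(td_target, episodic_return)
    out = td_target.clone()
    out[mask] = (1 - lb_target_q) * td_target[mask] + lb_target_q * bounded[mask]
    return out

def hindsight_goal_return(distance: torch.Tensor, gamma: float,
                          positive_reward: bool = True) -> torch.Tensor:
    # Discounted return of reaching the relabeled goal ``distance`` steps ahead:
    # gamma**d for a 0/1 goal reward, or
    # -(1 + gamma + ... + gamma**(d-1)) = -(1 - gamma**d) / (1 - gamma) for -1/0.
    if positive_reward:
        return gamma ** distance
    return -(1. - gamma ** distance) / (1. - gamma)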
+ +import unittest + +import numpy as np +import torch + +import alf +from alf.algorithms.td_loss import LowerBoundedTDLoss +from alf.data_structures import TimeStep, StepType, namedtuple + +DataItem = namedtuple( + "DataItem", ["reward", "step_type", "discount"], default_value=()) + + +class LowerBoundedTDLossTest(unittest.TestCase): + """Tests for alf.algorithms.td_loss.LowerBoundedTDLoss + """ + + def _check(self, res, expected): + np.testing.assert_array_almost_equal(res, expected) + + def test_compute_td_target_nstep_bootstrap_lowerbound(self): + loss = LowerBoundedTDLoss( + gamma=1., improve_w_nstep_bootstrap=True, td_lambda=1) + # Tensors are transposed to be time_major [T, B, ...] + step_types = torch.tensor([[StepType.MID] * 5], + dtype=torch.int64).transpose(0, 1) + rewards = torch.tensor([[2.] * 5], dtype=torch.float32).transpose(0, 1) + discounts = torch.tensor([[0.9] * 5], dtype=torch.float32).transpose( + 0, 1) + values = torch.tensor([[1.] * 5], dtype=torch.float32).transpose(0, 1) + info = DataItem( + reward=rewards, step_type=step_types, discount=discounts) + returns, value, _ = loss.compute_td_target(info, values, values) + expected_return = torch.tensor( + [[2 + 0.9 * (2 + 0.9 * (2 + 0.9 * (2 + 0.9))), 0, 0, 0]], + dtype=torch.float32).transpose(0, 1) + self._check(res=returns, expected=expected_return) + + expected_value = torch.tensor([[1, 0, 0, 0, 0]], + dtype=torch.float32).transpose(0, 1) + self._check(res=value, expected=expected_value) + + # n-step return is below 1-step + values[2:] = -10 + expected_return[0] = 2 + 0.9 + returns, value, _ = loss.compute_td_target(info, values, values) + self._check(res=returns, expected=expected_return) + + +if __name__ == '__main__': + alf.test.main() diff --git a/alf/examples/sac_breakout_conf-lbtq-Qbert.png b/alf/examples/sac_breakout_conf-lbtq-Qbert.png new file mode 100644 index 000000000..f839c95cb Binary files /dev/null and b/alf/examples/sac_breakout_conf-lbtq-Qbert.png differ diff --git a/alf/examples/sac_breakout_conf.py b/alf/examples/sac_breakout_conf.py index e6b163393..8c754d56e 100644 --- a/alf/examples/sac_breakout_conf.py +++ b/alf/examples/sac_breakout_conf.py @@ -12,10 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
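For reference, a quick hand check (illustration only, not part of the patch) of the lower bound expected in ``test_compute_td_target_nstep_bootstrap_lowerbound`` above. With ``gamma=1`` the effective per-step discount is ``info.discount = 0.9``, so the 1- to 4-step bootstrapped returns of the first step are:

# reward 2, value 1, discount 0.9, no episode end (all StepType.MID):
r, v, d = 2.0, 1.0, 0.9
rets = [r + d * v,                                # 1-step: 2.9
        r + d * (r + d * v),                      # 2-step: 4.61
        r + d * (r + d * (r + d * v)),            # 3-step: 6.149
        r + d * (r + d * (r + d * (r + d * v)))]  # 4-step: 7.5341
assert abs(max(rets) - 7.5341) < 1e-6  # the value target lower bound for step 0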
+# NOTE: to use this on a different atari game, add this flag: +# --conf_param='create_environment.env_name="QbertNoFrameskip-v4"' + +# NOTE: for lower bound value target improvement, add these flags: +# --conf_param='ReplayBuffer.keep_episodic_info=True' +# --conf_param='ReplayBuffer.record_episodic_return=True' +# --conf_param='LowerBoundedTDLoss.lb_target_q=True' + import functools import alf -from alf.algorithms.td_loss import TDLoss +from alf.algorithms.td_loss import LowerBoundedTDLoss from alf.environments.alf_wrappers import AtariTerminalOnLifeLossWrapper from alf.networks import QNetwork from alf.optimizers import AdamTF @@ -42,7 +50,7 @@ def define_config(name, default_value): fc_layer_params=FC_LAYER_PARAMS, conv_layer_params=CONV_LAYER_PARAMS) -critic_loss_ctor = functools.partial(TDLoss, td_lambda=0.95) +critic_loss_ctor = functools.partial(LowerBoundedTDLoss, td_lambda=0.95) lr = define_config('lr', 5e-4) critic_optimizer = AdamTF(lr=lr) @@ -82,7 +90,8 @@ def define_config(name, default_value): num_env_steps=12000000, evaluate=True, num_eval_episodes=100, - num_evals=10, + num_evals=50, + num_eval_environments=20, num_checkpoints=5, num_summaries=100, debug_summaries=True, diff --git a/alf/experience_replayers/replay_buffer.py b/alf/experience_replayers/replay_buffer.py index 6e95fe334..cffb47a22 100644 --- a/alf/experience_replayers/replay_buffer.py +++ b/alf/experience_replayers/replay_buffer.py @@ -29,12 +29,20 @@ from .segment_tree import SumSegmentTree, MaxSegmentTree +# her (Tensor): of shape (batch_size, batch_length) indicating which transitions are relabeled +# with hindsight. +# future_distance (Tensor): of shape (batch_size, batch_length), is the distance from +# the transition's end state to the sampled future state in terms of number of +# environment steps. future_distance[:, 0] == future_distance[:, n], so only the first step +# is accurate. BatchInfo = namedtuple( "BatchInfo", [ "env_ids", "positions", "importance_weights", "replay_buffer", + "her", + "future_distance", "discounted_return", ], default_value=()) diff --git a/alf/utils/value_ops.py b/alf/utils/value_ops.py index 8c36deff4..c99eca225 100644 --- a/alf/utils/value_ops.py +++ b/alf/utils/value_ops.py @@ -118,6 +118,72 @@ def action_importance_ratio(action_distribution, return importance_ratio, importance_ratio_clipped +def generalized_advantage_estimation(rewards, + values, + step_types, + discounts, + td_lambda=1.0, + time_major=True): + """Computes generalized advantage estimation (GAE) for the first T-1 steps. + + For theory, see + "High-Dimensional Continuous Control Using Generalized Advantage Estimation" + by John Schulman, Philipp Moritz et al. + See https://arxiv.org/abs/1506.02438 for full paper. + + The difference between this function and the one tf_agents.utils.value_ops + is that the accumulated_td is reset to 0 for is_last steps in this function. + + Define abbreviations: + - B: batch size representing number of trajectories + - T: number of steps per trajectory + + Args: + rewards (Tensor): shape is [T, B] (or [T]) representing rewards. + values (Tensor): shape is [T,B] (or [T]) representing values. + step_types (Tensor): shape is [T,B] (or [T]) representing step types. + discounts (Tensor): shape is [T, B] (or [T]) representing discounts. + td_lambda (float): A scalar between [0, 1]. It's used for variance + reduction in temporal difference. + time_major (bool): Whether input tensors are time major. + False means input tensors have shape [B, T]. 
+ + Returns: + A tensor with shape [T-1, B] representing advantages. Shape is [B, T-1] + when time_major is false. + """ + + if not time_major: + discounts = discounts.transpose(0, 1) + rewards = rewards.transpose(0, 1) + values = values.transpose(0, 1) + step_types = step_types.transpose(0, 1) + + assert values.shape[0] >= 2, ("The sequence length needs to be " + "at least 2. Got {s}".format( + s=values.shape[0])) + + is_lasts = (step_types == StepType.LAST).to(dtype=torch.float32) + is_lasts = common.expand_dims_as(is_lasts, values) + discounts = common.expand_dims_as(discounts, values) + + weighted_discounts = discounts[1:] * td_lambda + + advs = torch.zeros_like(values) + delta = rewards[1:] + discounts[1:] * values[1:] - values[:-1] + + with torch.no_grad(): + for t in reversed(range(rewards.shape[0] - 1)): + advs[t] = (1 - is_lasts[t]) * \ + (delta[t] + weighted_discounts[t] * advs[t + 1]) + advs = advs[:-1] + + if not time_major: + advs = advs.transpose(0, 1) + + return advs.detach() + + def discounted_return(rewards, values, step_types, discounts, time_major=True): """Computes discounted return for the first T-1 steps. @@ -180,24 +246,36 @@ def discounted_return(rewards, values, step_types, discounts, time_major=True): return rets.detach() -def one_step_discounted_return(rewards, values, step_types, discounts): - """Calculate the one step discounted return for the first T-1 steps. +def first_step_future_discounted_returns(rewards, + values, + step_types, + discounts, + time_major=True): + """Computes future 1 to n step discounted returns for the first step. - return = next_reward + next_discount * next_value if is not the last step; - otherwise will set return = current_discount * current_value. + Define abbreviations: + + - B: batch size representing number of trajectories + - T: number of steps per trajectory - Note: Input tensors must be time major Args: rewards (Tensor): shape is [T, B] (or [T]) representing rewards. - values (Tensor): shape is [T, B] (or [T]) when representing values, - [T, B, n_quantiles] or [T, n_quantiles] when representing quantiles - of value distributions. - step_types (Tensor): shape is [T, B] (or [T]) representing step types. + values (Tensor): shape is [T,B] (or [T]) representing values. + step_types (Tensor): shape is [T,B] (or [T]) representing step types. discounts (Tensor): shape is [T, B] (or [T]) representing discounts. + time_major (bool): Whether input tensors are time major. + False means input tensors have shape [B, T]. + Returns: A tensor with shape [T-1, B] (or [T-1]) representing the discounted - returns. + returns. Shape is [B, T-1] when time_major is false. """ + if not time_major: + discounts = discounts.transpose(0, 1) + rewards = rewards.transpose(0, 1) + values = values.transpose(0, 1) + step_types = step_types.transpose(0, 1) + assert values.shape[0] >= 2, ("The sequence length needs to be " "at least 2. 
Got {s}".format( s=values.shape[0])) @@ -205,56 +283,50 @@ def one_step_discounted_return(rewards, values, step_types, discounts): is_lasts = (step_types == StepType.LAST).to(dtype=torch.float32) is_lasts = common.expand_dims_as(is_lasts, values) discounts = common.expand_dims_as(discounts, values) - rewards = common.expand_dims_as(rewards, values) - discounted_values = discounts * values - rets = (1 - is_lasts[:-1]) * (rewards[1:] + discounted_values[1:]) + \ - is_lasts[:-1] * discounted_values[:-1] - return rets.detach() + accw = torch.ones_like(values) + accw[0] = (1 - is_lasts[0]) * discounts[1] + rets = torch.zeros_like(values) + rets[0] = rewards[1] * (1 - is_lasts[0]) + accw[0] * values[1] + # When ith is LAST, v[i+1] shouldn't be used in computing ret[i]. When disc[i] == 0, v[i] isn't used in computing ret[i-1]. + # when 2nd is LAST, ret[0] = r[1] + disc[1] * v[1], ret[1] = r[1] + disc[1] * (r[2] + disc[2] * v[2]), ret[2] = r[1] + disc[1] * (r[2] + disc[2] * v[2]) + # r[t] = (1 - is_last[t]) * reward[t + 1] + # acc_return_to[t] = acc_return_to[t - 1] + r[t] + # bootstrapped_return[t] = r[t] + (1 - is_last[t + 1]) * discounts[t + 1] * v[t + 1] + with torch.no_grad(): + for t in range(rewards.shape[0] - 2): + accw[t + 1] = accw[t] * (1 - is_lasts[t + 1]) * discounts[t + 2] + rets[t + 1] = ( + rets[t] + rewards[t + 2] * (1 - is_lasts[t + 1]) * accw[t] + + values[t + 2] * accw[t + 1] - + accw[t] * values[t + 1] * (1 - is_lasts[t + 1])) + rets = rets[:-1] -def generalized_advantage_estimation(rewards, - values, - step_types, - discounts, - td_lambda=1.0, - time_major=True): - """Computes generalized advantage estimation (GAE) for the first T-1 steps. + if not time_major: + rets = rets.transpose(0, 1) - For theory, see - "High-Dimensional Continuous Control Using Generalized Advantage Estimation" - by John Schulman, Philipp Moritz et al. - See https://arxiv.org/abs/1506.02438 for full paper. + return rets.detach() - The difference between this function and the one tf_agents.utils.value_ops - is that the accumulated_td is reset to 0 for is_last steps in this function. - Define abbreviations: +def one_step_discounted_return(rewards, values, step_types, discounts): + """Calculate the one step discounted return for the first T-1 steps. - - B: batch size representing number of trajectories - - T: number of steps per trajectory + return = next_reward + next_discount * next_value if is not the last step; + otherwise will set return = current_discount * current_value. + Note: Input tensors must be time major Args: rewards (Tensor): shape is [T, B] (or [T]) representing rewards. - values (Tensor): shape is [T,B] (or [T]) representing values. - step_types (Tensor): shape is [T,B] (or [T]) representing step types. + values (Tensor): shape is [T, B] (or [T]) when representing values, + [T, B, n_quantiles] or [T, n_quantiles] when representing quantiles + of value distributions. + step_types (Tensor): shape is [T, B] (or [T]) representing step types. discounts (Tensor): shape is [T, B] (or [T]) representing discounts. - td_lambda (float): A scalar between [0, 1]. It's used for variance - reduction in temporal difference. - time_major (bool): Whether input tensors are time major. - False means input tensors have shape [B, T]. - Returns: - A tensor with shape [T-1, B] representing advantages. Shape is [B, T-1] - when time_major is false. + A tensor with shape [T-1, B] (or [T-1]) representing the discounted + returns. 
""" - - if not time_major: - discounts = discounts.transpose(0, 1) - rewards = rewards.transpose(0, 1) - values = values.transpose(0, 1) - step_types = step_types.transpose(0, 1) - assert values.shape[0] >= 2, ("The sequence length needs to be " "at least 2. Got {s}".format( s=values.shape[0])) @@ -262,19 +334,9 @@ def generalized_advantage_estimation(rewards, is_lasts = (step_types == StepType.LAST).to(dtype=torch.float32) is_lasts = common.expand_dims_as(is_lasts, values) discounts = common.expand_dims_as(discounts, values) + rewards = common.expand_dims_as(rewards, values) - weighted_discounts = discounts[1:] * td_lambda - - advs = torch.zeros_like(values) - delta = rewards[1:] + discounts[1:] * values[1:] - values[:-1] - - with torch.no_grad(): - for t in reversed(range(rewards.shape[0] - 1)): - advs[t] = (1 - is_lasts[t]) * \ - (delta[t] + weighted_discounts[t] * advs[t + 1]) - advs = advs[:-1] - - if not time_major: - advs = advs.transpose(0, 1) - - return advs.detach() + discounted_values = discounts * values + rets = (1 - is_lasts[:-1]) * (rewards[1:] + discounted_values[1:]) + \ + is_lasts[:-1] * discounted_values[:-1] + return rets.detach() diff --git a/alf/utils/value_ops_test.py b/alf/utils/value_ops_test.py index ebd526127..6477edbb2 100644 --- a/alf/utils/value_ops_test.py +++ b/alf/utils/value_ops_test.py @@ -23,23 +23,46 @@ class DiscountedReturnTest(unittest.TestCase): """Tests for alf.utils.value_ops.discounted_return """ - def _check(self, rewards, values, step_types, discounts, expected): - np.testing.assert_array_almost_equal( - value_ops.discounted_return( + def _check(self, + rewards, + values, + step_types, + discounts, + expected, + future=False): + if future: + res = value_ops.first_step_future_discounted_returns( rewards=rewards, values=values, step_types=step_types, discounts=discounts, - time_major=False), expected) + time_major=False) + else: + res = value_ops.discounted_return( + rewards=rewards, + values=values, + step_types=step_types, + discounts=discounts, + time_major=False) - np.testing.assert_array_almost_equal( - value_ops.discounted_return( + np.testing.assert_array_almost_equal(res, expected) + + if future: + res = value_ops.first_step_future_discounted_returns( rewards=torch.stack([rewards, 2 * rewards], dim=2), values=torch.stack([values, 2 * values], dim=2), step_types=step_types, discounts=discounts, - time_major=False), torch.stack([expected, 2 * expected], - dim=2)) + time_major=False) + else: + res = value_ops.discounted_return( + rewards=torch.stack([rewards, 2 * rewards], dim=2), + values=torch.stack([values, 2 * values], dim=2), + step_types=step_types, + discounts=discounts, + time_major=False) + np.testing.assert_array_almost_equal( + res, torch.stack([expected, 2 * expected], dim=2)) def test_discounted_return(self): values = torch.tensor([[1.] * 5], dtype=torch.float32) @@ -74,7 +97,7 @@ def test_discounted_return(self): discounts=discounts, expected=expected) - # tow episodes, and end normal (discount=0) + # two episodes, and end normal (discount=0) step_types = torch.tensor([[ StepType.MID, StepType.MID, StepType.LAST, StepType.MID, StepType.MID @@ -91,6 +114,100 @@ def test_discounted_return(self): discounts=discounts, expected=expected) + def test_first_step_future_discounted_returns(self): + values = torch.tensor([[1.] * 5], dtype=torch.float32) + step_types = torch.tensor([[StepType.MID] * 5], dtype=torch.int64) + rewards = torch.tensor([[2.] 
* 5], dtype=torch.float32) + discounts = torch.tensor([[0.9] * 5], dtype=torch.float32) + expected = torch.tensor([[ + 2 + 0.9, 2 + 0.9 * (2 + 0.9), 2 + 0.9 * (2 + 0.9 * (2 + 0.9)), + 2 + 0.9 * (2 + 0.9 * (2 + 0.9 * (2 + 0.9))) + ]], + dtype=torch.float32) + self._check( + rewards=rewards, + values=values, + step_types=step_types, + discounts=discounts, + expected=expected, + future=True) + + # two episodes, and exceed by time limit (discount=1) + step_types = torch.tensor([[ + StepType.MID, StepType.MID, StepType.LAST, StepType.MID, + StepType.MID + ]], + dtype=torch.int32) + expected = torch.tensor([[ + 2 + 0.9, 2 + 0.9 * (2 + 0.9), 2 + 0.9 * (2 + 0.9), + 2 + 0.9 * (2 + 0.9) + ]], + dtype=torch.float32) + self._check( + rewards=rewards, + values=values, + step_types=step_types, + discounts=discounts, + expected=expected, + future=True) + + # two episodes, and end normal (discount=0) + step_types = torch.tensor([[ + StepType.MID, StepType.MID, StepType.LAST, StepType.MID, + StepType.MID + ]], + dtype=torch.int32) + discounts = torch.tensor([[0.9, 0.9, 0.0, 0.9, 0.9]]) + expected = torch.tensor( + [[2 + 0.9, 2 + 0.9 * 2, 2 + 0.9 * 2, 2 + 0.9 * 2]], + dtype=torch.float32) + + self._check( + rewards=rewards, + values=values, + step_types=step_types, + discounts=discounts, + expected=expected, + future=True) + + # two episodes with discount 0 LAST. + values = torch.tensor([[1.] * 5], dtype=torch.float32) + step_types = torch.tensor([[ + StepType.MID, StepType.LAST, StepType.LAST, StepType.MID, + StepType.MID + ]], + dtype=torch.int32) + rewards = torch.tensor([[2.] * 5], dtype=torch.float32) + discounts = torch.tensor([[0.9, 0.0, 0.0, 0.9, 0.9]]) + expected = torch.tensor([[2, 2, 2, 2]], dtype=torch.float32) + + self._check( + rewards=rewards, + values=values, + step_types=step_types, + discounts=discounts, + expected=expected, + future=True) + + # two episodes with discount 0 LAST. + values = torch.tensor([[1.] * 5], dtype=torch.float32) + step_types = torch.tensor([[ + StepType.LAST, StepType.LAST, StepType.LAST, StepType.MID, + StepType.MID + ]], + dtype=torch.int32) + rewards = torch.tensor([[2.] * 5], dtype=torch.float32) + discounts = torch.tensor([[0.0, 0.0, 0.0, 0.9, 0.9]]) + expected = torch.tensor([[0, 0, 0, 0]], dtype=torch.float32) + + self._check( + rewards=rewards, + values=values, + step_types=step_types, + discounts=discounts, + expected=expected, + future=True) + class GeneralizedAdvantageTest(unittest.TestCase): """Tests for alf.utils.value_ops.generalized_advantage_estimation
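A plain-Python reference (illustration only, not part of the patch) of what ``first_step_future_discounted_returns`` computes for a single time-major trajectory. It reproduces the expected values in the tests above, in particular the freezing of the return once the episode ends:

from alf.data_structures import StepType

def future_returns_first_step(rewards, values, step_types, discounts):
    """1- to (T-1)-step bootstrapped returns of step 0 for one length-T trajectory."""
    T = len(rewards)
    rets = [0.0] * (T - 1)
    acc_reward, acc_w, ended = 0.0, 1.0, False
    for k in range(1, T):
        if ended or step_types[k - 1] == StepType.LAST:
            # Episode already over: keep the last valid bootstrapped return
            # (0 if the very first step is already LAST).
            ended = True
            rets[k - 1] = rets[k - 2] if k >= 2 else 0.0
            continue
        acc_reward += acc_w * rewards[k]   # accumulate discounted rewards
        acc_w *= discounts[k]              # weight for the bootstrap value
        rets[k - 1] = acc_reward + acc_w * values[k]
    return rets

# Second case of test_first_step_future_discounted_returns (episode ends at step 2
# by time limit, so the discount stays 0.9):
future_returns_first_step(
    rewards=[2.] * 5, values=[1.] * 5,
    step_types=[StepType.MID, StepType.MID, StepType.LAST,
                StepType.MID, StepType.MID],
    discounts=[0.9] * 5)
# -> approximately [2.9, 4.61, 4.61, 4.61]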