minimum change for lower bounded value target (for episodic return, goal distance return, and n-step bootstrapped return)
Le Horizon committed Jun 27, 2022
1 parent dde33cd commit b9c4143
Showing 11 changed files with 648 additions and 119 deletions.
57 changes: 39 additions & 18 deletions alf/algorithms/data_transformer.py
@@ -736,14 +736,25 @@ class HindsightExperienceTransformer(DataTransformer):
of the current timestep.
The exact field names can be provided via arguments to the class ``__init__``.
NOTE: The HindsightExperienceTransformer must come before any transformer that changes the
reward or achieved_goal fields, e.g. observation normalizer, reward clipper, etc.
See `documentation <../../docs/notes/knowledge_base.rst#datatransformers>`_ for details.
To use this class, add it to any existing data transformers, e.g. use this config if
``ObservationNormalizer`` is an existing data transformer:
.. code-block:: python
ReplayBuffer.keep_episodic_info=True
HindsightExperienceTransformer.her_proportion=0.8
TrainerConfig.data_transformer_ctor=[@HindsightExperienceTransformer, @ObservationNormalizer]
alf.config('ReplayBuffer', keep_episodic_info=True)
alf.config(
'HindsightExperienceTransformer',
her_proportion=0.8
)
alf.config(
'TrainerConfig',
data_transformer_ctor=[
HindsightExperienceTransformer, ObservationNormalizer
])
See unit test for more details on behavior.
"""
@@ -820,9 +831,10 @@ def transform_experience(self, experience: Experience):
# relabel only these sampled indices
her_cond = torch.rand(batch_size) < her_proportion
(her_indices, ) = torch.where(her_cond)
has_her = torch.any(her_cond)

last_step_pos = start_pos[her_indices] + batch_length - 1
last_env_ids = env_ids[her_indices]
last_step_pos = start_pos + batch_length - 1
last_env_ids = env_ids
# Get x, y indices of LAST steps
dist = buffer.steps_to_episode_end(last_step_pos, last_env_ids)
if alf.summary.should_record_summaries():
@@ -831,22 +843,24 @@ def transform_experience(self, experience: Experience):
torch.mean(dist.type(torch.float32)))

# get random future state
future_idx = last_step_pos + (torch.rand(*dist.shape) *
(dist + 1)).to(torch.int64)
future_dist = (torch.rand(*dist.shape) * (dist + 1)).to(
torch.int64)
future_idx = last_step_pos + future_dist
future_ag = buffer.get_field(self._achieved_goal_field,
last_env_ids, future_idx).unsqueeze(1)

# relabel desired goal
result_desired_goal = alf.nest.get_field(result,
self._desired_goal_field)
relabed_goal = result_desired_goal.clone()
relabeled_goal = result_desired_goal.clone()
her_batch_index_tuple = (her_indices.unsqueeze(1),
torch.arange(batch_length).unsqueeze(0))
relabed_goal[her_batch_index_tuple] = future_ag
if has_her:
relabeled_goal[her_batch_index_tuple] = future_ag[her_indices]

# recompute rewards
result_ag = alf.nest.get_field(result, self._achieved_goal_field)
relabeled_rewards = self._reward_fn(result_ag, relabed_goal)
relabeled_rewards = self._reward_fn(result_ag, relabeled_goal)

non_her_or_fst = ~her_cond.unsqueeze(1) & (result.step_type !=
StepType.FIRST)
@@ -876,21 +890,28 @@ def transform_experience(self, experience: Experience):
alf.summary.scalar(
"replayer/" + buffer._name + ".reward_mean_before_relabel",
torch.mean(result.reward[her_indices][:-1]))
alf.summary.scalar(
"replayer/" + buffer._name + ".reward_mean_after_relabel",
torch.mean(relabeled_rewards[her_indices][:-1]))
if has_her:
alf.summary.scalar(
"replayer/" + buffer._name + ".reward_mean_after_relabel",
torch.mean(relabeled_rewards[her_indices][:-1]))
alf.summary.scalar("replayer/" + buffer._name + ".future_distance",
torch.mean(future_dist.float()))

result = alf.nest.transform_nest(
result, self._desired_goal_field, lambda _: relabed_goal)

result, self._desired_goal_field, lambda _: relabeled_goal)
result = result.update_time_step_field('reward', relabeled_rewards)

info = info._replace(her=her_cond, future_distance=future_dist)
if alf.get_default_device() != buffer.device:
for f in accessed_fields:
result = alf.nest.transform_nest(
result, f, lambda t: convert_device(t))
result = alf.nest.transform_nest(
result, "batch_info.replay_buffer", lambda _: buffer)
info = convert_device(info)
info = info._replace(
her=info.her.unsqueeze(1).expand(result.reward.shape[:2]),
future_distance=info.future_distance.unsqueeze(1).expand(
result.reward.shape[:2]),
replay_buffer=buffer)
result = alf.data_structures.add_batch_info(result, info)
return result
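Two small pieces of the indexing above are worth spelling out: ``(torch.rand(*dist.shape) * (dist + 1)).to(torch.int64)`` draws, for every sampled trajectory, an integer offset uniformly from {0, ..., dist}, and the per-trajectory ``her`` and ``future_distance`` entries are expanded to the [batch, batch_length] shape of the per-timestep tensors before being stored in ``batch_info``. A tiny standalone illustration (the shapes and numbers are made up):

import torch

B, T = 4, 2                                  # 4 sampled trajectories, each of length 2
dist = torch.tensor([0, 2, 5, 9])            # steps from each trajectory's last step to its episode end
her_cond = torch.rand(B) < 0.8               # which trajectories get hindsight-relabeled

# rand() lies in [0, 1), so scaling by (dist + 1) and truncating yields an
# integer offset in {0, ..., dist[i]}, each value equally likely.
future_dist = (torch.rand(B) * (dist + 1)).to(torch.int64)
assert torch.all(future_dist <= dist)

# Broadcast per-trajectory quantities to [B, T] so they line up with
# per-timestep tensors such as the reward.
her_bt = her_cond.unsqueeze(1).expand(B, T)             # [B, T] bool
future_dist_bt = future_dist.unsqueeze(1).expand(B, T)  # [B, T] int64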


17 changes: 14 additions & 3 deletions alf/algorithms/ddpg_algorithm.py
@@ -40,9 +40,20 @@
DdpgActorState = namedtuple("DdpgActorState", ['actor', 'critics'])
DdpgState = namedtuple("DdpgState", ['actor', 'critics'])
DdpgInfo = namedtuple(
"DdpgInfo", [
"reward", "step_type", "discount", "action", "action_distribution",
"actor_loss", "critic", "discounted_return"
"DdpgInfo",
[
"reward",
"step_type",
"discount",
"action",
"action_distribution",
"actor_loss",
"critic",
# Optional fields for value target lower bounding or Hindsight relabeling.
# TODO: Extract these into a HerAlgorithm wrapper for easier adoption of HER.
"discounted_return",
"future_distance",
"her"
],
default_value=())
DdpgLossInfo = namedtuple('DdpgLossInfo', ('actor', 'critic'))
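The new ``discounted_return``, ``future_distance``, and ``her`` fields are optional: with ``default_value=()``, ALF's ``namedtuple`` gives every field that is not explicitly set the value ``()``, so algorithms and losses that never populate them are unaffected, which is what keeps this a minimum change. A rough standalone equivalent using the stdlib ``namedtuple`` (illustrative only, not ALF's helper):

from collections import namedtuple

# Stand-in for ALF's namedtuple(..., default_value=()): every field not supplied
# defaults to the empty tuple, which downstream code reads as "not populated".
Info = namedtuple("Info", ["reward", "discounted_return", "her"],
                  defaults=[(), (), ()])

info = Info(reward=1.0)
assert info.discounted_return == ()    # no return recorded, so a loss would skip the lower bound
assert Info(reward=1.0, discounted_return=3.7).discounted_return == 3.7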
4 changes: 2 additions & 2 deletions alf/algorithms/one_step_loss.py
@@ -16,12 +16,12 @@
from typing import Union, List, Callable

import alf
from alf.algorithms.td_loss import TDLoss, TDQRLoss
from alf.algorithms.td_loss import LowerBoundedTDLoss, TDQRLoss
from alf.utils import losses


@alf.configurable
class OneStepTDLoss(TDLoss):
class OneStepTDLoss(LowerBoundedTDLoss):
def __init__(self,
gamma: Union[float, List[float]] = 0.99,
td_error_loss_fn: Callable = losses.element_wise_squared_loss,
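The only change to ``OneStepTDLoss`` is its base class: it now derives from ``LowerBoundedTDLoss`` (whose body is in ``td_loss.py`` and not shown here) instead of ``TDLoss``, so the one-step constructor and its call sites stay untouched while the lower-bounding machinery becomes available. A hypothetical miniature of that shape, with invented method names rather than the actual ALF classes:

import torch

class TDLoss:
    """Plain one-step TD target (illustrative stand-in)."""

    def __init__(self, gamma=0.99):
        self._gamma = gamma

    def td_target(self, reward, discount, next_value):
        return reward + self._gamma * discount * next_value


class LowerBoundedTDLoss(TDLoss):
    """Same target, optionally clamped from below by an observed return."""

    def td_target(self, reward, discount, next_value, discounted_return=None):
        target = super().td_target(reward, discount, next_value)
        if discounted_return is not None:
            # A realized return (episodic, goal-distance, or n-step bootstrapped,
            # per the commit title) is used as a lower bound on the value target.
            target = torch.maximum(target, discounted_return)
        return target


class OneStepTDLoss(LowerBoundedTDLoss):   # was: class OneStepTDLoss(TDLoss)
    pass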
19 changes: 16 additions & 3 deletions alf/algorithms/sac_algorithm.py
@@ -54,9 +54,22 @@
"SacActorInfo", ["actor_loss", "neg_entropy"], default_value=())

SacInfo = namedtuple(
"SacInfo", [
"reward", "step_type", "discount", "action", "action_distribution",
"actor", "critic", "alpha", "log_pi", "discounted_return"
"SacInfo",
[
"reward",
"step_type",
"discount",
"action",
"action_distribution",
"actor",
"critic",
"alpha",
"log_pi",
# Optional fields for value target lower bounding or Hindsight relabeling.
# TODO: Extract these into a HerAlgorithm wrapper for easier adoption of HER.
"discounted_return",
"future_distance",
"her"
],
default_value=())

(Diffs for the remaining changed files are not shown.)
