draft_retrace #695
base: pytorch
Conversation
take a look at
https://github.com/HorizonRobotics/alf/blob/pytorch/docs/contributing.rst
to format your code properly
alf/algorithms/td_loss.py
Outdated
@@ -99,15 +102,37 @@ def forward(self, experience, value, target_value):
    values=target_value,
    step_types=experience.step_type,
    discounts=experience.discount * self._gamma)
else:
elif train_info == None:
Instead of checking whether train_info is None, you should add an argument in __init__ to indicate whether to use retrace.
You should also change SarsaAlgorithm and SacAlgorithm to pass in train_info.
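The suggested change might look like the following sketch. The parameter names around `use_retrace` are taken from the diff in this thread; the class body is illustrative only, not the actual ALF implementation:

```python
# Illustrative sketch only: an explicit `use_retrace` flag instead of
# inferring retrace from `train_info is None`. Not the real ALF code.

class TDLoss(object):
    def __init__(self,
                 gamma=0.99,
                 td_error_loss_fn=None,
                 td_lambda=0.95,
                 normalize_target=False,
                 use_retrace=False,  # explicit boolean switch
                 debug_summaries=False):
        self._gamma = gamma
        self._td_lambda = td_lambda
        self._use_retrace = use_retrace
        self._debug_summaries = debug_summaries

    def forward(self, experience, value, target_value, train_info=None):
        if self._use_retrace:
            # retrace needs train_info to compute the importance ratio
            assert train_info is not None, (
                "train_info must be provided when use_retrace=True")
            return "retrace-loss"  # placeholder for the retrace branch
        return "td-loss"  # placeholder for the one/multi-step branch
```

With an explicit flag, callers that don't want retrace keep working unchanged, and a misconfiguration (retrace requested but no train_info) fails loudly instead of silently falling back.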
alf/utils/value_ops.py
Outdated
@@ -255,3 +255,36 @@ def generalized_advantage_estimation(rewards,
    advs = advs.transpose(0, 1)

    return advs.detach()
####### add for the retrace method
def generalized_advantage_estimation_retrace(importance_ratio, discounts, rewards, td_lambda, time_major, values, target_value,step_types):
please comment following the way of other functions.
Also need unittest for this function.
- line too long
- add space after ,
- comments for the function need to be added
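One way to address all three points (wrapped signature, spaces after commas, a docstring in the style of the surrounding functions). The argument list is copied from the diff; the body is intentionally omitted and the argument descriptions are plausible guesses, not the author's documentation:

```python
# Sketch of a properly formatted signature and docstring; the body is
# left out (see the PR diff for the actual implementation).

def generalized_advantage_estimation_retrace(importance_ratio, discounts,
                                             rewards, td_lambda, time_major,
                                             values, target_value, step_types):
    """Retrace(lambda) variant of generalized advantage estimation.

    Args:
        importance_ratio (torch.Tensor): truncated ratio between the target
            and behavior policy probabilities of the taken actions.
        discounts (torch.Tensor): discount at each time step.
        rewards (torch.Tensor): reward at each time step.
        td_lambda (float): lambda parameter of the trace.
        time_major (bool): whether the tensors are time-major.
        values (torch.Tensor): value estimates at each time step.
        target_value (torch.Tensor): values used for bootstrapping.
        step_types (torch.Tensor): step type at each time step.

    Returns:
        torch.Tensor: the estimated advantages.
    """
    raise NotImplementedError("body omitted; see the PR diff")
```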
alf/algorithms/sarsa_algorithm.py
Outdated
@@ -435,7 +435,7 @@ def calc_loss(self, experience, info: SarsaInfo):
    target_critic = tensor_utils.tensor_prepend_zero(
        info.target_critics)
    loss_info = self._critic_losses[i](shifted_experience, critic,
                                       target_critic)
                                       target_critic,info)
add space after ,
alf/algorithms/td_loss.py
Outdated
@@ -31,6 +31,10 @@ def __init__(self,
    td_error_loss_fn=element_wise_squared_loss,
    td_lambda=0.95,
    normalize_target=False,
some-feature-retrace
need to be removed
alf/algorithms/td_loss.py
Outdated
some-feature-retrace
use_retrace=0,

pytorch
need to be removed
alf/algorithms/td_loss.py
Outdated
@@ -76,8 +80,13 @@ def __init__(self,
    self._debug_summaries = debug_summaries
    self._normalize_target = normalize_target
    self._target_normalizer = None
some-feature-retrace
Remove; these look like leftover conflict markers from a merge.
alf/algorithms/td_loss.py
Outdated
def forward(self, experience, value, target_value):
pytorch
remove
else:
    scope = alf.summary.scope(self.__class__.__name__)
    importance_ratio,importance_ratio_clipped = value_ops.action_importance_ratio(
        action_distribution=train_info.action_distribution,
format, line is too long
Not fixed?
There seem to be many format issues. You may need to follow the workflow here to set up the formatting tools and also get a reference for the coding standard:
https://alf.readthedocs.io/en/latest/contributing.html#workflow
alf/algorithms/td_loss.py
Outdated
@@ -46,7 +50,7 @@ def __init__(self,
:math:`G_t^\lambda = \hat{A}^{GAE}_t + V(s_t)`
where the generalized advantage estimation is defined as:
:math:`\hat{A}^{GAE}_t = \sum_{i=t}^{T-1}(\gamma\lambda)^{i-t}(R_{i+1} + \gamma V(s_{i+1}) - V(s_i))`

use_retrace = 0 means one step or multi_step loss, use_retrace = 1 means retrace loss
Can change use_retrace to use a bool value.
Need to update comment
alf/algorithms/td_loss.py
Outdated
else:
    scope = alf.summary.scope(self.__class__.__name__)
    importance_ratio,importance_ratio_clipped = value_ops.action_importance_ratio(
add space after ,
@@ -170,7 +170,32 @@ def test_generalized_advantage_estimation(self):
    discounts=discounts,
    td_lambda=td_lambda,
    expected=expected)

class GeneralizedAdvantage_retrace_Test(unittest.TestCase):
    """Tests for alf.utils.value_ops
comments not correct
alf/algorithms/td_loss.py
Outdated
log_prob_clipping=0.0,
scope=scope,
check_numerics=False,
debug_summaries=True)
debug_summaries=debug_summaries
####### add for the retrace method
def generalized_advantage_estimation_retrace(importance_ratio, discounts,
This function can be merged with the generalized_advantage_estimation function.
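One way the merge could look: make plain GAE the special case where the truncated importance ratio is 1 everywhere. A plain-Python sketch under that assumption (the actual functions in value_ops.py work on time-major torch tensors with discounts and step_types):

```python
# Hypothetical merged function: GAE with an optional retrace-style
# importance ratio; passing importance_ratio=None gives ordinary GAE.

def generalized_advantage_estimation(rewards, values, gamma, td_lambda,
                                     importance_ratio=None):
    """GAE with an optional retrace trace coefficient.

    values has len(rewards) + 1 entries (bootstrap value at the end).
    """
    T = len(rewards)
    if importance_ratio is None:
        importance_ratio = [1.0] * T  # on-policy => ordinary GAE
    advs = [0.0] * T
    acc = 0.0
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        # truncated trace coefficient c_{t+1} = lambda * min(1, rho_{t+1})
        c = td_lambda * min(1.0, importance_ratio[t + 1]) if t + 1 < T else 0.0
        acc = delta + gamma * c * acc
        advs[t] = acc
    return advs
```

Merging this way avoids duplicating the backward recursion and keeps a single code path to test.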
alf/algorithms/td_loss.py
Outdated
@@ -91,6 +97,8 @@ def forward(self, experience, value, target_value):
    target_value (torch.Tensor): the time-major tensor for the value at
        each time step. This is used to calculate return. ``target_value``
        can be same as ``value``.
    train_info (sarsa info, sac info): information used to calcuate importance_ratio
What is sarsa info, sac info here? Can this function be used with other algorithms beyond sac and sarsa?
Changed code in files value_ops and td_loss. The default value for train_info is None. If the train_info parameter is given and lambda is not equal to 1 or 0, the retrace method will be used. So we do not need to change the code of sac_algorithm or sarsa_algorithm when other people do not want the retrace method.
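The dispatch rule described above can be written as a small predicate (a sketch; the function name is hypothetical, and in the real code the check would live inside TDLoss.forward):

```python
# Retrace is taken only when train_info is provided AND td_lambda is
# strictly between 0 and 1; otherwise the one-step/multi-step TD path
# is used, so existing callers need no changes.

def should_use_retrace(train_info, td_lambda):
    """Return True if the retrace path should be taken."""
    return train_info is not None and 0.0 < td_lambda < 1.0
```

Note that an explicit boolean flag in __init__ (as the reviewer suggested earlier) would make this choice visible in configs rather than implicit in the arguments.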