fix(nyz): fix api doc bugs
PaParaZz1 committed Oct 16, 2023
1 parent 2aa3165 commit e6eea3d
Showing 6 changed files with 31 additions and 49 deletions.
17 changes: 9 additions & 8 deletions ding/model/wrapper/model_wrappers.py
@@ -118,9 +118,10 @@ def __init__(
Arguments:
- model(:obj:`Any`): Wrapped model class, should contain forward method.
- state_num (:obj:`int`): Number of states to process.
- save_prev_state (:obj:`bool`): Whether to output the prev state in output['prev_state'].
- init_fn (:obj:`Callable`): The function which is used to init every hidden state when init and reset. \
Default return None for hidden states.
- save_prev_state (:obj:`bool`): Whether to output the prev state in output.
- init_fn (:obj:`Callable`): The function which is used to init every hidden state when init and reset, \
default return None for hidden states.
.. note::
1. This helper must deal with an actual batch with some parts of samples, e.g: 6 samples of state_num 8.
2. This helper must deal with the single sample state reset.
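For orientation, a minimal usage sketch of the wrapper documented above. It assumes this hunk belongs to HiddenStateWrapper, that the wrapper is reachable through ding.model.model_wrap under wrapper_name='hidden_state', and that the forward/reset keyword arguments shown here are accepted; DRQN only stands in as an arbitrary recurrent model.
>>> import torch
>>> from ding.model import DRQN, model_wrap
>>> model = DRQN(obs_shape=4, action_shape=2)
>>> wrapped = model_wrap(model, wrapper_name='hidden_state', state_num=8, save_prev_state=True)
>>> # partial batch as in note 1: 6 samples while state_num is 8, addressed by env id
>>> out = wrapped.forward({'obs': torch.randn(6, 4)}, data_id=[0, 1, 2, 3, 4, 5], inference=True)
>>> out['prev_state']            # present because save_prev_state=True
>>> wrapped.reset(data_id=[2])   # single-sample state reset as in note 2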
@@ -218,11 +219,11 @@ def forward(self,
**kwargs) -> Dict[str, torch.Tensor]:
"""
Arguments:
- input_obs (:obj:`torch.Tensor`): Input observation without sequence shape: (bs, *obs_shape)
- only_last_logit (:obj:`bool`): if True 'logit' only contains the output corresponding to the current
observation (shape: bs, embedding_dim), otherwise logit has shape (seq_len, bs, embedding_dim)
- data_id (:obj:`List`): id of the envs that are currently running. Memory update and logits return has only
effect for those environments. If `None` it is considered that all envs are running.
- input_obs (:obj:`torch.Tensor`): Input observation without sequence shape: ``(bs, *obs_shape)``.
- only_last_logit (:obj:`bool`): if True 'logit' only contains the output corresponding to the current \
observation (shape: bs, embedding_dim), otherwise logit has shape (seq_len, bs, embedding_dim).
- data_id (:obj:`List`): id of the envs that are currently running. Memory update and logits return has \
only effect for those environments. If `None` it is considered that all envs are running.
Returns:
- Dictionary containing the input_sequence 'input_seq' stored in memory and the transformer output 'logit'.
"""
5 changes: 2 additions & 3 deletions ding/rl_utils/acer.py
@@ -16,7 +16,7 @@ def acer_policy_error(
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Overview:
Get ACER policy loss
Get ACER policy loss.
Arguments:
- q_values (:obj:`torch.Tensor`): Q values
- q_retraces (:obj:`torch.Tensor`): Q values (be calculated by retrace method)
@@ -64,7 +64,7 @@ def acer_policy_error(
def acer_value_error(q_values, q_retraces, actions):
"""
Overview:
Get ACER critic loss
Get ACER critic loss.
Arguments:
- q_values (:obj:`torch.Tensor`): Q values
- q_retraces (:obj:`torch.Tensor`): Q values (be calculated by retrace method)
@@ -78,7 +78,6 @@ def acer_value_error(q_values, q_retraces, actions):
- actions (:obj:`torch.LongTensor`): :math:`(T, B)`
- critic_loss (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
Examples:
>>> q_values=torch.randn(2, 3, 4)
>>> q_retraces=torch.randn(2, 3, 1)
>>> actions=torch.randint(0, 4, (2, 3))
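The example above is truncated by the diff view; a hedged completion, assuming acer_value_error keeps the signature shown in the hunk header and is imported straight from this file.
>>> import torch
>>> from ding.rl_utils.acer import acer_value_error
>>> q_values = torch.randn(2, 3, 4)        # (T, B, N)
>>> q_retraces = torch.randn(2, 3, 1)      # (T, B, 1)
>>> actions = torch.randint(0, 4, (2, 3))  # (T, B)
>>> critic_loss = acer_value_error(q_values, q_retraces, actions)  # expected shape (T, B, 1)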
4 changes: 2 additions & 2 deletions ding/rl_utils/ppg.py
@@ -27,8 +27,8 @@ def ppg_joint_error(
- action (:obj:`torch.LongTensor`): :math:`(B,)`
- value_new (:obj:`torch.FloatTensor`): :math:`(B, 1)`
- value_old (:obj:`torch.FloatTensor`): :math:`(B, 1)`
- return_ (:obj:`torch.FloatTensor`): :math:`(B, 1)`
- weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B,)`
- return (:obj:`torch.FloatTensor`): :math:`(B, 1)`
- weight (:obj:`torch.FloatTensor`): :math:`(B,)`
- auxiliary_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
- behavioral_cloning_loss (:obj:`torch.FloatTensor`): :math:`()`
Examples:
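A hedged sketch using the shapes listed above. The ppg_data namedtuple name, its field order, and how the two losses are returned are assumptions about this module, not facts taken from the diff.
>>> import torch
>>> from ding.rl_utils.ppg import ppg_data, ppg_joint_error
>>> B, N = 4, 6
>>> data = ppg_data(torch.randn(B, N), torch.randn(B, N), torch.randint(0, N, (B,)),
...                 torch.randn(B, 1), torch.randn(B, 1), torch.randn(B, 1), None)
>>> losses = ppg_joint_error(data)  # expected to expose auxiliary_loss and behavioral_cloning_loss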
17 changes: 6 additions & 11 deletions ding/rl_utils/td.py
@@ -237,7 +237,7 @@ def nstep_return(data: namedtuple, gamma: Union[float, list], nstep: int, value_
- nstep (:obj:`int`): nstep num
- value_gamma (:obj:`torch.Tensor`): Discount factor for value
Returns:
- return_ (:obj:`torch.Tensor`): nstep return
- return (:obj:`torch.Tensor`): nstep return
Shapes:
- data (:obj:`nstep_return_data`): the nstep_return_data containing\
['reward', 'next_value', 'done']
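The Shapes list is cut off by the diff view; a hedged sketch of the call, in which the import path and the (T, B)/(B,) tensor shapes are assumptions based on the surrounding docstrings.
>>> import torch
>>> from ding.rl_utils.td import nstep_return, nstep_return_data
>>> T, B, nstep = 3, 4, 3
>>> data = nstep_return_data(reward=torch.randn(T, B), next_value=torch.randn(B), done=torch.zeros(B))
>>> ret = nstep_return(data, 0.99, nstep)  # expected shape (B,)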
@@ -718,15 +718,10 @@ def bdq_nstep_td_error(
) -> torch.Tensor:
"""
Overview:
Multistep (1 step or n step) td_error for BDQ algorithm, \
referenced paper Action Branching Architectures for Deep Reinforcement Learning \
<https://arxiv.org/pdf/1711.08946>
In fact, the original paper only provides the 1-step TD-error calculation method, \
and here we extend the calculation method of n-step.
TD-error:
y_d = \sigma_{t=0}^{nstep} \gamma^t * r_t + \gamma^{nstep} * Q_d'(s', argmax Q_d(s', a_d))
TD-error = \frac{1}{D} * (y_d - Q_d(s, a_d))^2
Loss = mean(TD-error)
Multistep (1 step or n step) td_error for BDQ algorithm, referenced paper "Action Branching Architectures \
for Deep Reinforcement Learning", link: https://arxiv.org/pdf/1711.08946.
In fact, the original paper only provides the 1-step TD-error calculation method, and here we extend the \
calculation method of n-step TD-error.
Arguments:
- data (:obj:`q_nstep_td_data`): The input data, q_nstep_td_data to calculate loss
- gamma (:obj:`float`): Discount factor
@@ -738,7 +733,7 @@
- loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
- td_error_per_sample (:obj:`torch.Tensor`): nstep td error, 1-dim tensor
Shapes:
- data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing\
- data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing \
['q', 'next_n_q', 'action', 'reward', 'done']
- q (:obj:`torch.FloatTensor`): :math:`(B, D, N)` i.e. [batch_size, branch_num, action_bins_per_branch]
- next_n_q (:obj:`torch.FloatTensor`): :math:`(B, D, N)`
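For reference, the n-step target that the removed lines spelled out can be restated in cleaner notation; this is the same content as the deleted docstring, with the average over the D branches written out explicitly as an assumption:
y_d = \sum_{t=0}^{n-1} \gamma^{t} r_t + \gamma^{n} Q_d'\big(s', \arg\max_{a_d} Q_d(s', a_d)\big)
\text{TD-error} = \frac{1}{D} \sum_{d=1}^{D} \big(y_d - Q_d(s, a_d)\big)^2, \qquad \text{Loss} = \operatorname{mean}(\text{TD-error})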
8 changes: 4 additions & 4 deletions ding/rl_utils/vtrace.py
@@ -15,7 +15,7 @@ def vtrace_nstep_return(clipped_rhos, clipped_cs, reward, bootstrap_values, gamm
Shapes:
- clipped_rhos (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep, B is batch size
- clipped_cs (:obj:`torch.FloatTensor`): :math:`(T, B)`
- reward: (:obj:`torch.FloatTensor`): :math:`(T, B)`
- reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
- bootstrap_values (:obj:`torch.FloatTensor`): :math:`(T+1, B)`
- vtrace_return (:obj:`torch.FloatTensor`): :math:`(T, B)`
"""
@@ -37,10 +37,10 @@ def vtrace_advantage(clipped_pg_rhos, reward, return_, bootstrap_values, gamma):
- vtrace_advantage (:obj:`namedtuple`): the vtrace loss item, all of them are the differentiable 0-dim tensor
Shapes:
- clipped_pg_rhos (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep, B is batch size
- reward: (:obj:`torch.FloatTensor`): :math:`(T, B)`
- return_ (:obj:`torch.FloatTensor`): :math:`(T, B)`
- reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
- return (:obj:`torch.FloatTensor`): :math:`(T, B)`
- bootstrap_values (:obj:`torch.FloatTensor`): :math:`(T, B)`
- vtrace_advantage (:obj:`torch.FloatTensor`): :math:`(T, B)`
- vtrace_advantage (:obj:`torch.FloatTensor`): :math:`(T, B)`
"""
return clipped_pg_rhos * (reward + gamma * return_ - bootstrap_values)
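A hedged sketch exercising vtrace_advantage with the (T, B) shapes documented above; the import follows the file path of this hunk.
>>> import torch
>>> from ding.rl_utils.vtrace import vtrace_advantage
>>> T, B = 5, 3
>>> clipped_pg_rhos, reward = torch.rand(T, B), torch.randn(T, B)
>>> return_, bootstrap_values = torch.randn(T, B), torch.randn(T, B)
>>> adv = vtrace_advantage(clipped_pg_rhos, reward, return_, bootstrap_values, 0.99)  # (T, B)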

29 changes: 8 additions & 21 deletions ding/torch_utils/network/activation.py
@@ -2,7 +2,6 @@

import torch
import torch.nn as nn
import torch.nn.functional as F


class Lambda(nn.Module):
@@ -16,30 +15,18 @@ def forward(self, x):


class GLU(nn.Module):
r"""
"""
Overview:
Gating Linear Unit.
This class does a thing like this:
.. code:: python
# Inputs: input, context, output_size
# The gate value is a learnt function of the input.
gate = sigmoid(linear(input.size)(context))
# Gate the input and return an output of desired size.
gated_input = gate * input
output = linear(output_size)(gated_input)
return output
Interfaces:
forward
``forward``.
.. tip::
This module also supports 2D convolution, in which case, the input and context must have the same shape.
"""

def __init__(self, input_dim: int, output_dim: int, context_dim: int, input_type: str = 'fc') -> None:
r"""
"""
Overview:
Init GLU
Arguments:
@@ -101,15 +88,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


def build_activation(activation: str, inplace: bool = None) -> nn.Module:
r"""
"""
Overview:
Return the activation module according to the given type.
Arguments:
- activation (:obj:`str`): the type of activation module, now supports \
['relu', 'glu', 'prelu', 'swish', 'gelu', 'tanh', 'sigmoid', 'softplus', 'elu', 'square', 'identity']
- inplace (:obj:`bool`): can optionally do the operation in-place in relu. Default ``None``
- activation (:obj:`str`): The type of activation module, now supports \
['relu', 'glu', 'prelu', 'swish', 'gelu', 'tanh', 'sigmoid', 'softplus', 'elu', 'square', 'identity'].
- inplace (:obj:`bool`): Execute the operation in-place in activation, defaults to ``None``.
Returns:
- act_func (:obj:`nn.module`): the corresponding activation module
- act_func (:obj:`nn.module`): The corresponding activation module.
"""
if inplace is not None:
assert activation == 'relu', 'inplace argument is not compatible with {}'.format(activation)
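Finally, a hedged usage sketch of build_activation as documented above, importing from the file path of this hunk.
>>> from ding.torch_utils.network.activation import build_activation
>>> act = build_activation('relu', inplace=True)  # inplace is only accepted together with 'relu'
>>> act2 = build_activation('tanh')               # any other name from the supported list above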
