diff --git a/ding/model/common/encoder.py b/ding/model/common/encoder.py index e22112601e..82dab4808a 100644 --- a/ding/model/common/encoder.py +++ b/ding/model/common/encoder.py @@ -23,7 +23,7 @@ def prod(iterable): class ConvEncoder(nn.Module): """ Overview: - The ``Convolution Encoder`` used to encode raw 2-dim image observations (e.g. Atari/Procgen). + The Convolution Encoder is used to encode 2-dim image observations. Interfaces: ``__init__``, ``forward``. """ @@ -106,6 +106,18 @@ def _get_flatten_size(self) -> int: - outputs (:obj:`torch.Tensor`): Size ``int`` Tensor representing the number of ``in-features``. Shapes: - outputs: :math:`(1,)`. + Examples: + >>> conv = ConvEncoder( + >>> obs_shape=(4, 84, 84), + >>> hidden_size_list=[32, 64, 64, 128], + >>> activation=nn.ReLU(), + >>> kernel_size=[8, 4, 3], + >>> stride=[4, 2, 1], + >>> padding=None, + >>> layer_norm=False, + >>> norm_type=None + >>> ) + >>> flatten_size = conv._get_flatten_size() """ test_data = torch.randn(1, *self.obs_shape) with torch.no_grad(): @@ -123,6 +135,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Shapes: - x : :math:`(B, C, H, W)`, where ``B`` is batch size, ``C`` is channel, ``H`` is height, ``W`` is width. - outputs: :math:`(B, N)`, where ``N = hidden_size_list[-1]`` . + Examples: + >>> conv = ConvEncoder( + >>> obs_shape=(4, 84, 84), + >>> hidden_size_list=[32, 64, 64, 128], + >>> activation=nn.ReLU(), + >>> kernel_size=[8, 4, 3], + >>> stride=[4, 2, 1], + >>> padding=None, + >>> layer_norm=False, + >>> norm_type=None + >>> ) + >>> x = torch.randn(1, 4, 84, 84) + >>> output = conv(x) """ x = self.main(x) x = self.mid(x) @@ -132,7 +157,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class FCEncoder(nn.Module): """ Overview: - The ``FCEncoder`` used in models to encode raw 1-dim observations (e.g. MuJoCo). + The full connected encoder is used to encode 1-dim input variable. Interfaces: ``__init__``, ``forward``. """ @@ -148,7 +173,7 @@ def __init__( ) -> None: """ Overview: - Init the FC Encoder according to arguments. + Initialize the FC Encoder according to arguments. Arguments: - obs_shape (:obj:`int`): Observation shape. - hidden_size_list (:obj:`SequenceType`): Sequence of ``hidden_size`` of subsequent FC layers. @@ -194,6 +219,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Shapes: - x : :math:`(B, M)`, where ``M = obs_shape``. - outputs: :math:`(B, N)`, where ``N = hidden_size_list[-1]``. + Examples: + >>> fc = FCEncoder( + >>> obs_shape=4, + >>> hidden_size_list=[32, 64, 64, 128], + >>> activation=nn.ReLU(), + >>> norm_type=None, + >>> dropout=None + >>> ) + >>> x = torch.randn(1, 4) + >>> output = fc(x) """ x = self.act(self.init(x)) x = self.main(x) @@ -211,13 +246,18 @@ def __init__(self, obs_shape: Dict[str, Union[int, List[int]]]) -> None: class IMPALACnnResidualBlock(nn.Module): """ Overview: - Residual basic block (without batchnorm) in IMPALA CNN encoder, which preserves the channel number and shape. + This CNN encoder residual block is residual basic block used in IMPALA algorithm, + which preserves the channel number and shape. + IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures + https://arxiv.org/pdf/1802.01561.pdf + Interfaces: + ``__init__``, ``forward``. """ def __init__(self, in_channnel: int, scale: float = 1, batch_norm: bool = False): """ Overview: - Init the IMPALA CNN residual block according to arguments. + Initialize the IMPALA CNN residual block according to arguments. 
Arguments: - in_channnel (:obj:`int`): Channel number of input features. - scale (:obj:`float`): Scale of module, defaults to 1. @@ -234,9 +274,16 @@ def __init__(self, in_channnel: int, scale: float = 1, batch_norm: bool = False) self.bn1 = nn.BatchNorm2d(self.in_channnel) def residual(self, x: torch.Tensor) -> torch.Tensor: - # inplace should be False for the first relu, so that it does not change the input, - # which will be used for skip connection. - # getattr is for backwards compatibility with loaded models + """ + Overview: + Return output tensor of the residual block, keep the shape and channel number unchanged. + The inplace of activation function should be False for the first relu, + so that it does not change the origin input tensor of the residual block. + Arguments: + - x (:obj:`torch.Tensor`): Input tensor. + Returns: + - output (:obj:`torch.Tensor`): Output tensor. + """ if self.batch_norm: x = self.bn0(x) x = F.relu(x, inplace=False) @@ -255,6 +302,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: - x (:obj:`torch.Tensor`): Input tensor. Returns: - output (:obj:`torch.Tensor`): Output tensor. + Examples: + >>> block = IMPALACnnResidualBlock(16) + >>> x = torch.randn(1, 16, 84, 84) + >>> output = block(x) """ return x + self.residual(x) @@ -262,13 +313,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class IMPALACnnDownStack(nn.Module): """ Overview: - Downsampling stack from IMPALA CNN encoder, which reduces the spatial size by 2 with maxpooling. + Downsampling stack of CNN encoder used in IMPALA algorithmn. + Every IMPALACnnDownStack consists n IMPALACnnResidualBlock, + which reduces the spatial size by 2 with maxpooling. + IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures + https://arxiv.org/pdf/1802.01561.pdf + Interfaces: + ``__init__``, ``forward``. """ def __init__(self, in_channnel, nblock, out_channel, scale=1, pool=True, **kwargs): """ Overview: - Init every impala cnn block of the Impala Cnn Encoder. + Initialize every impala cnn block of the Impala Cnn Encoder. Arguments: - in_channnel (:obj:`int`): Channel number of input features. - nblock (:obj:`int`): Residual Block number in each block. @@ -293,6 +350,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: - x (:obj:`torch.Tensor`): Input tensor. Returns: - output (:obj:`torch.Tensor`): Output tensor. + Examples: + >>> stack = IMPALACnnDownStack(16, 2, 32) + >>> x = torch.randn(1, 16, 84, 84) + >>> output = stack(x) """ x = self.firstconv(x) if self.pool: @@ -305,6 +366,17 @@ def output_shape(self, inshape: tuple) -> tuple: """ Overview: Calculate the output shape of the downsampling stack according to input shape and related arguments. + Arguments: + - inshape (:obj:`tuple`): Input shape. + Returns: + - output_shape (:obj:`tuple`): Output shape. + Shapes: + - inshape (:obj:`tuple`): :math:`(C, H, W)`, where C is channel number, H is height and W is width. + - output_shape (:obj:`tuple`): :math:`(C, H, W)`, where C is channel number, H is height and W is width. + Examples: + >>> stack = IMPALACnnDownStack(16, 2, 32) + >>> inshape = (16, 84, 84) + >>> output_shape = stack.output_shape(inshape) """ c, h, w = inshape assert c == self.in_channnel @@ -317,7 +389,7 @@ def output_shape(self, inshape: tuple) -> tuple: class IMPALAConvEncoder(nn.Module): """ Overview: - IMPALA CNN encoder, which is used in IMPALA algorithm. \ + IMPALA CNN encoder, which is used in IMPALA algorithm. 
IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures, \ https://arxiv.org/pdf/1802.01561.pdf, Interface: @@ -337,7 +409,7 @@ def __init__( ) -> None: """ Overview: - Init the IMPALA CNN encoder according to arguments. + Initialize the IMPALA CNN encoder according to arguments. Arguments: - obs_shape (:obj:`SequenceType`): 2D image observation shape. - channels (:obj:`SequenceType`): The channel number of a series of impala cnn blocks. \ @@ -348,6 +420,7 @@ def __init__( observation, such as dividing 255.0 for the raw image observation. - nblock (:obj:`int`): The number of Residual Block in each block. - final_relu (:obj:`bool`): Whether to use ReLU activation in the final output of encoder. + - kwargs (:obj:`Dict[str, Any]`): Other arguments for ``IMPALACnnDownStack``. """ super().__init__() self.scale_ob = scale_ob @@ -375,6 +448,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: - x (:obj:`torch.Tensor`): :math:`(B, C, H, W)`, where B is batch size, C is channel number, H is height \ and W is width. - output (:obj:`torch.Tensor`): :math:`(B, outsize)`, where B is batch size. + Examples: + >>> encoder = IMPALAConvEncoder( + >>> obs_shape=(4, 84, 84), + >>> channels=(16, 32, 32), + >>> outsize=256, + >>> scale_ob=255.0, + >>> nblock=2, + >>> final_relu=True, + >>> ) + >>> x = torch.randn(1, 4, 84, 84) + >>> output = encoder(x) """ x = x / self.scale_ob for (i, layer) in enumerate(self.stacks): diff --git a/ding/model/common/head.py b/ding/model/common/head.py index c1d27fba89..30f5b58d98 100755 --- a/ding/model/common/head.py +++ b/ding/model/common/head.py @@ -14,8 +14,8 @@ class DiscreteHead(nn.Module): """ Overview: - The ``DiscreteHead`` used to output discrete actions logit or Q-value logit, which is often used in DQN \ - and policy head in actor-critic algorithms for discrete action space. + The ``DiscreteHead`` is used to generate discrete actions logit or Q-value logit, \ + which is often used in q-learning algorithms or actor-critic algorithms for discrete action space. Interfaces: ``__init__``, ``forward``. """ @@ -73,7 +73,6 @@ def forward(self, x: torch.Tensor) -> Dict: Shapes: - x: :math:`(B, N)`, where ``B = batch_size`` and ``N = hidden_size``. - logit: :math:`(B, M)`, where ``M = output_size``. - Examples: >>> head = DiscreteHead(64, 64) >>> inputs = torch.randn(4, 64) @@ -87,7 +86,8 @@ def forward(self, x: torch.Tensor) -> Dict: class DistributionHead(nn.Module): """ Overview: - The ``DistributionHead`` used to output Q-value distribution, which is often used in C51 algorithm. + The ``DistributionHead`` is used to generate distribution for Q-value. + This module is used in C51 algorithm. Interfaces: ``__init__``, ``forward``. """ @@ -156,7 +156,6 @@ def forward(self, x: torch.Tensor) -> Dict: - x: :math:`(B, N)`, where ``B = batch_size`` and ``N = hidden_size``. - logit: :math:`(B, M)`, where ``M = output_size``. - distribution: :math:`(B, M, n_atom)`. - Examples: >>> head = DistributionHead(64, 64) >>> inputs = torch.randn(4, 64) @@ -177,7 +176,8 @@ def forward(self, x: torch.Tensor) -> Dict: class BranchingHead(nn.Module): """ Overview: - The ``BranchingHead`` used to output different branches Q-value, which is used in Branch DQN. + The ``BranchingHead`` is used to generate Q-value with different branches. + This module is used in Branch DQN. Interfaces: ``__init__``, ``forward``. """ @@ -267,7 +267,6 @@ def forward(self, x: torch.Tensor) -> Dict: Shapes: - x: :math:`(B, N)`, where ``B = batch_size`` and ``N = hidden_size``. 
- logit: :math:`(B, M)`, where ``M = output_size``. - Examples: >>> head = BranchingHead(64, 5, 2) >>> inputs = torch.randn(4, 64) @@ -290,7 +289,8 @@ def forward(self, x: torch.Tensor) -> Dict: class RainbowHead(nn.Module): """ Overview: - The ``RainbowHead`` used to output Q-value distribution, which is used in Rainbow DQN. + The ``RainbowHead`` is used to generate distribution of Q-value. + This module is used in Rainbow DQN. Interfaces: ``__init__``, ``forward``. """ @@ -370,7 +370,6 @@ def forward(self, x: torch.Tensor) -> Dict: - x: :math:`(B, N)`, where ``B = batch_size`` and ``N = hidden_size``. - logit: :math:`(B, M)`, where ``M = output_size``. - distribution: :math:`(B, M, n_atom)`. - Examples: >>> head = RainbowHead(64, 64) >>> inputs = torch.randn(4, 64) @@ -394,7 +393,7 @@ def forward(self, x: torch.Tensor) -> Dict: class QRDQNHead(nn.Module): """ Overview: - The ``QRDQNHead`` (Quantile Regression DQN) used to output action quantiles. + The ``QRDQNHead`` (Quantile Regression DQN) is used to output action quantiles. Interfaces: ``__init__``, ``forward``. """ @@ -455,7 +454,6 @@ def forward(self, x: torch.Tensor) -> Dict: - logit: :math:`(B, M)`, where ``M = output_size``. - q: :math:`(B, M, num_quantiles)`. - tau: :math:`(B, M, 1)`. - Examples: >>> head = QRDQNHead(64, 64) >>> inputs = torch.randn(4, 64) @@ -478,9 +476,15 @@ def forward(self, x: torch.Tensor) -> Dict: class QuantileHead(nn.Module): """ Overview: - The ``QuantileHead`` used to output action quantiles, which is used in IQN. + The ``QuantileHead`` is used to output action quantiles. + This module is used in IQN. Interfaces: ``__init__``, ``forward``, ``quantile_net``. + + .. note:: + The difference between ``QuantileHead`` and ``QRDQNHead`` is that ``QuantileHead`` models the \ + state-action quantile function as a mapping from state-actions and samples from some base distribution \ + while ``QRDQNHead`` approximates random returns by a uniform mixture of Diracs functions. """ def __init__( @@ -574,7 +578,6 @@ def forward(self, x: torch.Tensor, num_quantiles: Optional[int] = None) -> Dict: - logit: :math:`(B, M)`, where ``M = output_size``. - q: :math:`(num_quantiles, B, M)`. - quantiles: :math:`(quantile_embedding_size, 1)`. - Examples: >>> head = QuantileHead(64, 64) >>> inputs = torch.randn(4, 64) @@ -609,9 +612,18 @@ def forward(self, x: torch.Tensor, num_quantiles: Optional[int] = None) -> Dict: class FQFHead(nn.Module): """ Overview: - The ``FQFHead`` used to output action quantiles, which is used in ``FQF``. + The ``FQFHead`` is used to output action quantiles. + This module is used in FQF. Interfaces: ``__init__``, ``forward``, ``quantile_net``. + + .. note:: + The implementation of FQFHead is based on the paper https://arxiv.org/abs/1911.02140. + The difference between FQFHead and QuantileHead is that, in FQF, \ + N adjustable quantile values for N adjustable quantile fractions are estimated to approximate \ + the quantile function. The distribution of the return is approximated by a weighted mixture of N \ + Diracs functions. While in IQN, the state-action quantile function is modeled as a mapping from \ + state-actions and samples from some base distribution. """ def __init__( @@ -779,7 +791,8 @@ def forward(self, x: torch.Tensor, num_quantiles: Optional[int] = None) -> Dict: class DuelingHead(nn.Module): """ Overview: - The ``DuelingHead`` used to output discrete actions logit, which is used in Dueling DQN. + The ``DuelingHead`` is used to output discrete actions logit. 
+ This module is used in Dueling DQN. Interfaces: ``__init__``, ``forward``. """ @@ -857,7 +870,6 @@ def forward(self, x: torch.Tensor) -> Dict: Shapes: - x: :math:`(B, N)`, where ``B = batch_size`` and ``N = hidden_size``. - logit: :math:`(B, M)`, where ``M = output_size``. - Examples: >>> head = DuelingHead(64, 64) >>> inputs = torch.randn(4, 64) @@ -874,7 +886,7 @@ def forward(self, x: torch.Tensor) -> Dict: class StochasticDuelingHead(nn.Module): """ Overview: - The ``Stochastic Dueling Network`` proposed in paper ACER (arxiv 1611.01224). \ + The ``Stochastic Dueling Network`` is proposed in paper ACER (arxiv 1611.01224). \ That is to say, dueling network architecture in continuous action space. Interfaces: ``__init__``, ``forward``. @@ -975,6 +987,16 @@ def forward( - sigma: :math:`(B, A)`. - q_value: :math:`(B, 1)`. - v_value: :math:`(B, 1)`. + Examples: + >>> head = StochasticDuelingHead(64, 64) + >>> inputs = torch.randn(4, 64) + >>> a = torch.randn(4, 64) + >>> mu = torch.randn(4, 64) + >>> sigma = torch.ones(4, 64) + >>> outputs = head(inputs, a, mu, sigma) + >>> assert isinstance(outputs, dict) + >>> assert outputs['q_value'].shape == torch.Size([4, 1]) + >>> assert outputs['v_value'].shape == torch.Size([4, 1]) """ batch_size = s.shape[0] # batch_size or batch_size * T @@ -1005,8 +1027,9 @@ def forward( class RegressionHead(nn.Module): """ Overview: - The ``RegressionHead`` used to output continuous actions Q-value (DDPG critic), state value (A2C/PPO), and \ - directly predict continuous action (DDPG actor). + The ``RegressionHead`` is used to regress continuous variables. + This module is used for generating Q-value (DDPG critic) of continuous actions, \ + or state value (A2C/PPO), or directly predicting continuous action (DDPG actor). Interfaces: ``__init__``, ``forward``. """ @@ -1054,7 +1077,6 @@ def forward(self, x: torch.Tensor) -> Dict: Shapes: - x: :math:`(B, N)`, where ``B = batch_size`` and ``N = hidden_size``. - pred: :math:`(B, M)`, where ``M = output_size``. - Examples: >>> head = RegressionHead(64, 64) >>> inputs = torch.randn(4, 64) @@ -1074,7 +1096,9 @@ def forward(self, x: torch.Tensor) -> Dict: class ReparameterizationHead(nn.Module): """ Overview: - The ``ReparameterizationHead`` used to output action ``mu`` and ``sigma``, which is often used in PPO and SAC. + The ``ReparameterizationHead`` is used to generate Gaussian distribution of continuous variable, \ + which is parameterized by ``mu`` and ``sigma``. + This module is often used in stochastic policies, such as PPO and SAC. Interfaces: ``__init__``, ``forward``. """ @@ -1146,7 +1170,6 @@ def forward(self, x: torch.Tensor) -> Dict: - x: :math:`(B, N)`, where ``B = batch_size`` and ``N = hidden_size``. - mu: :math:`(B, M)`, where ``M = output_size``. - sigma: :math:`(B, M)`. - Examples: >>> head = ReparameterizationHead(64, 64, sigma_type='fixed') >>> inputs = torch.randn(4, 64) @@ -1173,7 +1196,10 @@ def forward(self, x: torch.Tensor) -> Dict: class PopArtVHead(nn.Module): """ Overview: - The ``PopArtVHead`` used to output adaptive normalized state value, which is used in PPO/IMPALA. + The ``PopArtVHead`` is used to generate adaptive normalized state value. More information can be found in \ + paper Multi-task Deep Reinforcement Learning with PopArt. \ + https://arxiv.org/abs/1809.04474 \ + This module is used in PPO or IMPALA. Interfaces: ``__init__``, ``forward``. """ @@ -1261,7 +1287,12 @@ def forward(self, key: torch.Tensor, query: torch.Tensor) -> torch.Tensor: ``K = hidden_size``. 
- query: :math:`(B, K)`. - logit: :math:`(B, N)`. - + Examples: + >>> head = AttentionPolicyHead() + >>> key = torch.randn(4, 5, 64) + >>> query = torch.randn(4, 64) + >>> logit = head(key, query) + >>> assert logit.shape == torch.Size([4, 5]) .. note:: In this head, we assume that the ``key`` and ``query`` tensor are both normalized. """ @@ -1274,8 +1305,8 @@ def forward(self, key: torch.Tensor, query: torch.Tensor) -> torch.Tensor: class MultiHead(nn.Module): """ Overview: - The ``MultiHead`` used to output multiple similar results. For example, we can combine ``Distribution`` and \ - ``MultiHead`` to output multi-discrete action space logit. + The ``MultiHead`` is used to generate multiple similar results. + For example, we can combine ``Distribution`` and ``MultiHead`` to generate multi-discrete action space logit. Interfaces: ``__init__``, ``forward``. """ @@ -1308,7 +1339,6 @@ def forward(self, x: torch.Tensor) -> Dict: Shapes: - x: :math:`(B, N)`, where ``B = batch_size`` and ``N = hidden_size``. - logit: :math:`(B, Mi)`, where ``Mi = output_size`` corresponding to output ``i``. - Examples: >>> head = MultiHead(DuelingHead, 64, [2, 3, 5], v_layer_num=2) >>> inputs = torch.randn(4, 64) @@ -1329,7 +1359,7 @@ def forward(self, x: torch.Tensor) -> Dict: class EnsembleHead(nn.Module): """ Overview: - The ``EnsembleHead`` used to output action Q-value for Q-ensemble in model-based RL algorithms. + The ``EnsembleHead`` is used to generate Q-value for Q-ensemble in model-based RL algorithms. Interfaces: ``__init__``, ``forward``. """ @@ -1400,11 +1430,18 @@ def forward(self, x: torch.Tensor) -> Dict: def independent_normal_dist(logits: Union[List, Dict]) -> torch.distributions.Distribution: """ Overview: - The compatibility function to convert different types logit to independent normal distribution. + Convert different types logit to independent normal distribution. Arguments: - logits (:obj:`Union[List, Dict]`): The logits to be converted. Returns: - dist (:obj:`torch.distributions.Distribution`): The converted normal distribution. + Examples: + >>> logits = [torch.randn(4, 5), torch.ones(4, 5)] + >>> dist = independent_normal_dist(logits) + >>> assert isinstance(dist, torch.distributions.Independent) + >>> assert isinstance(dist.base_dist, torch.distributions.Normal) + >>> assert dist.base_dist.loc.shape == torch.Size([4, 5]) + >>> assert dist.base_dist.scale.shape == torch.Size([4, 5]) Raises: - TypeError: If the type of logits is not ``list`` or ``dict``. """ diff --git a/ding/model/common/utils.py b/ding/model/common/utils.py index 0f508de0b8..0ca8df7fb5 100644 --- a/ding/model/common/utils.py +++ b/ding/model/common/utils.py @@ -13,7 +13,14 @@ def create_model(cfg: EasyDict) -> torch.nn.Module: used to import modules, and they key ``type`` is used to indicate the model. Returns: - (:obj:`torch.nn.Module`): The created neural network model. - + Examples: + >>> cfg = EasyDict({ + >>> 'import_names': ['ding.model.template.q_learning'], + >>> 'type': 'dqn', + >>> 'obs_shape': 4, + >>> 'action_shape': 2, + >>> }) + >>> model = create_model(cfg) .. tip:: This method will not modify the ``cfg`` , it will deepcopy the ``cfg`` and then modify it. """ diff --git a/ding/model/template/acer.py b/ding/model/template/acer.py index 2e28ef0b2c..bb46b22bec 100644 --- a/ding/model/template/acer.py +++ b/ding/model/template/acer.py @@ -9,9 +9,11 @@ @MODEL_REGISTRY.register('acer') class ACER(nn.Module): - r""" + """ Overview: - The ACER model. 
+ The model of algorithmn ACER(Actor Critic with Experience Replay) + Sample Efficient Actor-Critic with Experience Replay. + https://arxiv.org/abs/1611.01224 Interfaces: ``__init__``, ``forward``, ``compute_actor``, ``compute_critic`` """ @@ -29,7 +31,7 @@ def __init__( activation: Optional[nn.Module] = nn.ReLU(), norm_type: Optional[str] = None, ) -> None: - r""" + """ Overview: Init the ACER Model according to arguments. Arguments: @@ -78,10 +80,10 @@ def __init__( self.critic = nn.ModuleList(self.critic) def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict: - r""" + """ Overview: - Use observation to predict output. - Parameter updates with ACER's MLPs forward setup. + Use observation to predict output. + Parameter updates with ACER's MLPs forward setup. Arguments: Forward with ``'compute_actor'``: - inputs (:obj:`torch.Tensor`): @@ -101,11 +103,9 @@ def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict: Forward with ``'compute_critic'``, Necessary Keys: - q_value (:obj:`torch.Tensor`): Q value tensor. - Actor Shapes: - obs (:obj:`torch.Tensor`): :math:`(B, N1)`, where B is batch size and N1 is ``obs_shape`` - logit (:obj:`torch.FloatTensor`): :math:`(B, N2)`, where B is batch size and N2 is ``action_shape`` - Critic Shapes: - inputs (:obj:`torch.Tensor`): :math:`(B, N1)`, B is batch size and N1 corresponds to ``obs_shape`` - q_value (:obj:`torch.FloatTensor`): :math:`(B, N2)`, where B is batch size and N2 is ``action_shape`` @@ -115,24 +115,16 @@ def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict: >>> inputs = torch.randn(4, 64) >>> actor_outputs = model(inputs,'compute_actor') >>> assert actor_outputs['logit'].shape == torch.Size([4, 64]) - Critic Examples: >>> inputs = torch.randn(4,N) >>> model = ACER(obs_shape=(N, ),action_shape=5) - >>> model(inputs, mode='compute_critic')['q_value'] # q value - tensor([[-0.0681, -0.0431, -0.0530, 0.1454, -0.1093], - [-0.0647, -0.0281, -0.0527, 0.1409, -0.1162], - [-0.0596, -0.0321, -0.0676, 0.1386, -0.1113], - [-0.0874, -0.0406, -0.0487, 0.1346, -0.1135]], - grad_fn=) - - + >>> model(inputs, mode='compute_critic')['q_value'] """ assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode) return getattr(self, mode)(inputs) def compute_actor(self, inputs: torch.Tensor) -> Dict: - r""" + """ Overview: Use encoded embedding tensor to predict output. Execute parameter updates with ``'compute_actor'`` mode @@ -144,7 +136,6 @@ def compute_actor(self, inputs: torch.Tensor) -> Dict: - mode (:obj:`str`): Name of the forward mode. Returns: - outputs (:obj:`Dict`): Outputs of forward pass encoder and head. - ReturnsKeys (either): - logit (:obj:`torch.FloatTensor`): :math:`(B, N1)`, where B is batch size and N1 is ``action_shape`` Shapes: @@ -163,7 +154,7 @@ def compute_actor(self, inputs: torch.Tensor) -> Dict: return x def compute_critic(self, inputs: torch.Tensor) -> Dict: - r""" + """ Overview: Execute parameter updates with ``'compute_critic'`` mode Use encoded embedding tensor to predict output. @@ -172,22 +163,15 @@ def compute_critic(self, inputs: torch.Tensor) -> Dict: - mode (:obj:`str`): Name of the forward mode. Returns: - outputs (:obj:`Dict`): Q-value output. - ReturnKeys: - q_value (:obj:`torch.Tensor`): Q value tensor with same size as batch size. Shapes: - obs (:obj:`torch.Tensor`): :math:`(B, N1)`, where B is batch size and N1 is ``obs_shape`` - q_value (:obj:`torch.FloatTensor`): :math:`(B, N2)`, where B is batch size and N2 is ``action_shape``. 
- Examples: >>> inputs =torch.randn(4, N) >>> model = ACER(obs_shape=(N, ),action_shape=5) - >>> model(inputs, mode='compute_critic')['q_value'] # q value - tensor([[-0.0681, -0.0431, -0.0530, 0.1454, -0.1093], - [-0.0647, -0.0281, -0.0527, 0.1409, -0.1162], - [-0.0596, -0.0321, -0.0676, 0.1386, -0.1113], - [-0.0874, -0.0406, -0.0487, 0.1346, -0.1135]], - grad_fn=) + >>> model(inputs, mode='compute_critic')['q_value'] """ obs = inputs diff --git a/ding/model/template/atoc.py b/ding/model/template/atoc.py index f0863481c2..a06f536aef 100644 --- a/ding/model/template/atoc.py +++ b/ding/model/template/atoc.py @@ -5,28 +5,24 @@ from ding.utils import squeeze, MODEL_REGISTRY, SequenceType from ding.torch_utils import MLP -from ..common import RegressionHead +from ding.model.common import RegressionHead class ATOCAttentionUnit(nn.Module): - r""" + """ Overview: - the attention unit of the atoc network. We now implement it as two-layer MLP, same as the original paper - + The attention unit of the ATOC network. We now implement it as two-layer MLP, same as the original paper. Interface: - __init__, forward + ``__init__``, ``forward`` .. note:: - "ATOC paper: We use two-layer MLP to implement the attention unit but it is also can be realized by RNN." - """ def __init__(self, thought_size: int, embedding_size: int) -> None: - r""" + """ Overview: - init the attention unit according to the size of input args - + Initialize the attention unit according to the size of input arguments. Arguments: - thought_size (:obj:`int`): the size of input thought - embedding_size (:obj:`int`): the size of hidden layers @@ -42,15 +38,19 @@ def __init__(self, thought_size: int, embedding_size: int) -> None: self._act2 = nn.Sigmoid() def forward(self, data: Union[Dict, torch.Tensor]) -> torch.Tensor: - r""" + """ Overview: - forward method take the thought of agents as input and output the prob of these agent\ - being initiator - + Take the thought of agents as input and generate the probability of these agent being initiator Arguments: - x (:obj:`Union[Dict, torch.Tensor`): the input tensor or dict contain the thoughts tensor - - ret (:obj:`torch.Tensor`): the output initiator prob - + - ret (:obj:`torch.Tensor`): the output initiator probability + Shapes: + - data['thought']: :math:`(M, B, N)`, M is the num of thoughts to integrate,\ + B is batch_size and N is thought size + Examples: + >>> attention_unit = ATOCAttentionUnit(64, 64) + >>> thought = torch.randn(2, 3, 64) + >>> attention_unit(thought) """ x = data if isinstance(data, Dict): @@ -61,24 +61,21 @@ def forward(self, data: Union[Dict, torch.Tensor]) -> torch.Tensor: x = self._act1(x) x = self._fc3(x) x = self._act2(x) - # return {'initiator': x} return x.squeeze(-1) class ATOCCommunicationNet(nn.Module): - r""" + """ Overview: - atoc commnication net is a bi-direction LSTM, so it can integrate all the thoughts in the group - + This ATOC commnication net is a bi-direction LSTM, so it can integrate all the thoughts in the group. Interface: - __init__, forward + ``__init__``, ``forward`` """ def __init__(self, thought_size: int) -> None: - r""" + """ Overview: - init method of the communication network - + Initialize the communication network according to the size of input arguments. 
Arguments: - thought_size (:obj:`int`): the size of input thought @@ -93,32 +90,34 @@ def __init__(self, thought_size: int) -> None: self._bi_lstm = nn.LSTM(self._thought_size, self._comm_hidden_size, bidirectional=True) def forward(self, data: Union[Dict, torch.Tensor]): - r""" + """ Overview: - the forward method that integrate thoughts + The forward of ATOCCommunicationNet integrates thoughts in the group. Arguments: - x (:obj:`Union[Dict, torch.Tensor`): the input tensor or dict contain the thoughts tensor - out (:obj:`torch.Tensor`): the integrated thoughts Shapes: - data['thoughts']: :math:`(M, B, N)`, M is the num of thoughts to integrate,\ B is batch_size and N is thought size + Examples: + >>> comm_net = ATOCCommunicationNet(64) + >>> thoughts = torch.randn(2, 3, 64) + >>> comm_net(thoughts) """ self._bi_lstm.flatten_parameters() x = data if isinstance(data, Dict): x = data['thoughts'] out, _ = self._bi_lstm(x) - # return {'thoughts': out} return out class ATOCActorNet(nn.Module): - r""" + """ Overview: - the overall ATOC actor network - + The actor network of ATOC. Interface: - __init__, forward + ``__init__``, ``forward`` .. note:: "ATOC paper: The neural networks use ReLU and batch normalization for some hidden layers." @@ -139,10 +138,9 @@ def __init__( activation: Optional[nn.Module] = nn.ReLU(), norm_type: Optional[str] = None, ): - r""" + """ Overview: - the init method of atoc actor network - + Initialize the actor network of ATOC Arguments: - obs_shape(:obj:`Union[Tuple, int]`): the observation size - thought_size (:obj:`int`): the size of thoughts @@ -194,11 +192,9 @@ def __init__( self.comm_net = ATOCCommunicationNet(self._thought_size) def forward(self, obs: torch.Tensor) -> Dict: - r""" + """ Overview: - the forward method of actor network, take the input obs, and calculate the corresponding action, group, \ - initiator_prob, thoughts, etc... - + Take the input obs, and calculate the corresponding action, group, initiator_prob, thoughts, etc... 
Arguments: - obs (:obj:`Dict`): the input obs containing the observation Returns: @@ -207,6 +203,18 @@ def forward(self, obs: torch.Tensor) -> Dict: ReturnsKeys: - necessary: ``action`` - optional: ``group``, ``initiator_prob``, ``is_initiator``, ``new_thoughts``, ``old_thoughts`` + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, A, N)`, where B is batch size, A is agent num, N is obs size + - action (:obj:`torch.Tensor`): :math:`(B, A, M)`, where M is action size + - group (:obj:`torch.Tensor`): :math:`(B, A, A)` + - initiator_prob (:obj:`torch.Tensor`): :math:`(B, A)` + - is_initiator (:obj:`torch.Tensor`): :math:`(B, A)` + - new_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)` + - old_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)` + Examples: + >>> actor_net = ATOCActorNet(64, 64, 64, 3) + >>> obs = torch.randn(2, 3, 64) + >>> actor_net(obs) """ assert len(obs.shape) == 3 self._cur_batch_size = obs.shape[0] @@ -238,6 +246,25 @@ def forward(self, obs: torch.Tensor) -> Dict: return {'action': action} def _get_initiate_group(self, current_thoughts): + """ + Overview: + Calculate the initiator probability, group and is_initiator + Arguments: + - current_thoughts (:obj:`torch.Tensor`): tensor of current thoughts + Returns: + - init_prob (:obj:`torch.Tensor`): tesnor of initiator probability + - is_initiator (:obj:`torch.Tensor`): tensor of is initiator + - group (:obj:`torch.Tensor`): tensor of group + Shapes: + - current_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)`, where M is thought size + - init_prob (:obj:`torch.Tensor`): :math:`(B, A)` + - is_initiator (:obj:`torch.Tensor`): :math:`(B, A)` + - group (:obj:`torch.Tensor`): :math:`(B, A, A)` + Examples: + >>> actor_net = ATOCActorNet(64, 64, 64, 3) + >>> current_thoughts = torch.randn(2, 3, 64) + >>> actor_net._get_initiate_group(current_thoughts) + """ if not self._communication: raise NotImplementedError init_prob = self.attention(current_thoughts) # B, A @@ -267,10 +294,25 @@ def _get_initiate_group(self, current_thoughts): def _get_new_thoughts(self, current_thoughts, group, is_initiator): """ + Overview: + Calculate the new thoughts according to current thoughts, group and is_initiator + Arguments: + - current_thoughts (:obj:`torch.Tensor`): tensor of current thoughts + - group (:obj:`torch.Tensor`): tensor of group + - is_initiator (:obj:`torch.Tensor`): tensor of is initiator + Returns: + - new_thoughts (:obj:`torch.Tensor`): tensor of new thoughts Shapes: - current_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)`, where M is thought size - group: (:obj:`torch.Tensor`): :math:`(B, A, A)` - is_initiator (:obj:`torch.Tensor`): :math:`(B, A)` + - new_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)` + Examples: + >>> actor_net = ATOCActorNet(64, 64, 64, 3) + >>> current_thoughts = torch.randn(2, 3, 64) + >>> group = torch.randn(2, 3, 3) + >>> is_initiator = torch.randn(2, 3) + >>> actor_net._get_new_thoughts(current_thoughts, group, is_initiator) """ if not self._communication: raise NotImplementedError @@ -306,12 +348,13 @@ def _get_new_thoughts(self, current_thoughts, group, is_initiator): @MODEL_REGISTRY.register('atoc') class ATOC(nn.Module): - r""" + """ Overview: The QAC network of ATOC, a kind of extension of DDPG for MARL. 
- + Learning Attentional Communication for Multi-Agent Cooperation + https://arxiv.org/abs/1805.07733 Interface: - __init__, forward, compute_critic, compute_actor, optimize_actor_attention + ``__init__``, ``forward``, ``compute_critic``, ``compute_actor``, ``optimize_actor_attention`` """ mode = ['compute_actor', 'compute_critic', 'optimize_actor_attention'] @@ -330,10 +373,9 @@ def __init__( activation: Optional[nn.Module] = nn.ReLU(), norm_type: Optional[str] = None, ) -> None: - r""" + """ Overview: - init the atoc QAC network - + Initialize the ATOC QAC network Arguments: - obs_shape(:obj:`Union[Tuple, int]`): the observation space shape - thought_size (:obj:`int`): the size of thoughts @@ -367,16 +409,33 @@ def __init__( ) def _compute_delta_q(self, obs: torch.Tensor, actor_outputs: Dict) -> torch.Tensor: - r""" + """ Overview: calculate the delta_q according to obs and actor_outputs - Arguments: - obs (:obj:`torch.Tensor`): the observations - actor_outputs (:obj:`dict`): the output of actors - delta_q (:obj:`Dict`): the calculated delta_q + Returns: + - delta_q (:obj:`Dict`): the calculated delta_q ArgumentsKeys: - necessary: ``new_thoughts``, ``old_thoughts``, ``group``, ``is_initiator`` + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, A, N)`, where B is batch size, A is agent num, N is obs size + - actor_outputs (:obj:`Dict`): the output of actor network, including ``action``, ``new_thoughts``, \ + ``old_thoughts``, ``group``, ``initiator_prob``, ``is_initiator`` + - action (:obj:`torch.Tensor`): :math:`(B, A, M)` where M is action size + - new_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)` where M is thought size + - old_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)` where M is thought size + - group (:obj:`torch.Tensor`): :math:`(B, A, A)` + - initiator_prob (:obj:`torch.Tensor`): :math:`(B, A)` + - is_initiator (:obj:`torch.Tensor`): :math:`(B, A)` + - delta_q (:obj:`torch.Tensor`): :math:`(B, A)` + Examples: + >>> net = ATOC(64, 64, 64, 3) + >>> obs = torch.randn(2, 3, 64) + >>> actor_outputs = net.compute_actor(obs) + >>> net._compute_delta_q(obs, actor_outputs) """ if not self._communication: raise NotImplementedError @@ -413,10 +472,9 @@ def _compute_delta_q(self, obs: torch.Tensor, actor_outputs: Dict) -> torch.Tens return curr_delta_q def compute_actor(self, obs: torch.Tensor, get_delta_q: bool = False) -> Dict[str, torch.Tensor]: - r''' + ''' Overview: compute the action according to inputs, call the _compute_delta_q function to compute delta_q - Arguments: - obs (:obj:`torch.Tensor`): observation - get_delta_q (:obj:`bool`) : whether need to get delta_q @@ -425,7 +483,19 @@ def compute_actor(self, obs: torch.Tensor, get_delta_q: bool = False) -> Dict[st ReturnsKeys: - necessary: ``action`` - optional: ``group``, ``initiator_prob``, ``is_initiator``, ``new_thoughts``, ``old_thoughts``, ``delta_q`` - + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, A, N)`, where B is batch size, A is agent num, N is obs size + - action (:obj:`torch.Tensor`): :math:`(B, A, M)`, where M is action size + - group (:obj:`torch.Tensor`): :math:`(B, A, A)` + - initiator_prob (:obj:`torch.Tensor`): :math:`(B, A)` + - is_initiator (:obj:`torch.Tensor`): :math:`(B, A)` + - new_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)` + - old_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)` + - delta_q (:obj:`torch.Tensor`): :math:`(B, A)` + Examples: + >>> net = ATOC(64, 64, 64, 3) + >>> obs = torch.randn(2, 3, 64) + >>> net.compute_actor(obs) ''' outputs = self.actor(obs) if get_delta_q and 
self._communication: @@ -435,10 +505,25 @@ def compute_actor(self, obs: torch.Tensor, get_delta_q: bool = False) -> Dict[st def compute_critic(self, inputs: Dict) -> Dict: """ + Overview: + compute the q_value according to inputs + Arguments: + - inputs (:obj:`Dict`): the inputs contain the obs and action + Returns: + - outputs (:obj:`Dict`): the output of critic network ArgumentsKeys: - necessary: ``obs``, ``action`` ReturnsKeys: - necessary: ``q_value`` + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, A, N)`, where B is batch size, A is agent num, N is obs size + - action (:obj:`torch.Tensor`): :math:`(B, A, M)`, where M is action size + - q_value (:obj:`torch.Tensor`): :math:`(B, A)` + Examples: + >>> net = ATOC(64, 64, 64, 3) + >>> obs = torch.randn(2, 3, 64) + >>> action = torch.randn(2, 3, 64) + >>> net.compute_critic({'obs': obs, 'action': action}) """ obs, action = inputs['obs'], inputs['action'] if len(action.shape) == 2: # (B, A) -> (B, A, 1) @@ -448,14 +533,31 @@ def compute_critic(self, inputs: Dict) -> Dict: return {'q_value': x} def optimize_actor_attention(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - r""" + """ Overview: return the actor attention loss - Arguments: - inputs (:obj:`Dict`): the inputs contain the delta_q, initiator_prob, and is_initiator Returns - loss (:obj:`Dict`): the loss of actor attention unit + ArgumentsKeys: + - necessary: ``delta_q``, ``initiator_prob``, ``is_initiator`` + ReturnsKeys: + - necessary: ``loss`` + Shapes: + - delta_q (:obj:`torch.Tensor`): :math:`(B, A)` + - initiator_prob (:obj:`torch.Tensor`): :math:`(B, A)` + - is_initiator (:obj:`torch.Tensor`): :math:`(B, A)` + - loss (:obj:`torch.Tensor`): :math:`(1)` + Examples: + >>> net = ATOC(64, 64, 64, 3) + >>> delta_q = torch.randn(2, 3) + >>> initiator_prob = torch.randn(2, 3) + >>> is_initiator = torch.randn(2, 3) + >>> net.optimize_actor_attention( + >>> {'delta_q': delta_q, + >>> 'initiator_prob': initiator_prob, + >>> 'is_initiator': is_initiator}) """ if not self._communication: raise NotImplementedError diff --git a/ding/model/template/bc.py b/ding/model/template/bc.py index ce58ca8c5f..5348c750a6 100644 --- a/ding/model/template/bc.py +++ b/ding/model/template/bc.py @@ -188,13 +188,20 @@ def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Dict: - inputs (:obj:`torch.Tensor`): Observation data, defaults to tensor. Returns: - output (:obj:`Dict`): Output dict data, including different key-values among distinct action_space. - + ReturnsKeys: + - action (:obj:`torch.Tensor`): action output of actor network, \ + with shape :math:`(B, action_shape)`. + - logit (:obj:`List[torch.Tensor]`): reparameterized action output of actor network, \ + with shape :math:`(B, action_shape)`. 
+ Shapes: + - inputs (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape`` + - action (:obj:`torch.FloatTensor`): :math:`(B, M)`, where B is batch size and M is ``action_shape`` + - logit (:obj:`List[torch.FloatTensor]`): :math:`(B, M)`, where B is batch size and M is ``action_shape`` Examples (Regression): >>> model = ContinuousBC(32, 6, action_space='regression') >>> inputs = torch.randn(4, 32) >>> outputs = model(inputs) >>> assert isinstance(outputs, dict) and outputs['action'].shape == torch.Size([4, 6]) - Examples (Reparameterization): >>> model = ContinuousBC(32, 6, action_space='reparameterization') >>> inputs = torch.randn(4, 32) diff --git a/ding/model/template/bcq.py b/ding/model/template/bcq.py index 7b8d013e9e..0e72927a76 100755 --- a/ding/model/template/bcq.py +++ b/ding/model/template/bcq.py @@ -11,6 +11,16 @@ @MODEL_REGISTRY.register('bcq') class BCQ(nn.Module): + """ + Overview: + Model of BCQ (Batch-Constrained deep Q-learning). + Off-Policy Deep Reinforcement Learning without Exploration. + https://arxiv.org/abs/1812.02900 + Interface: + ``forward``, ``compute_actor``, ``compute_critic``, ``compute_vae``, ``compute_eval`` + Property: + ``mode`` + """ mode = ['compute_actor', 'compute_critic', 'compute_vae', 'compute_eval'] @@ -93,7 +103,13 @@ def forward(self, inputs: Dict[str, torch.Tensor], mode: str) -> Dict[str, torch - inputs (:obj:`Dict`): Input dict data, including obs and action tensor. Returns: - output (:obj:`Dict`): Output dict data, including action tensor. - + Examples: + >>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)} + >>> model = BCQ(32, 6) + >>> outputs = model(inputs, mode='compute_actor') + >>> outputs = model(inputs, mode='compute_critic') + >>> outputs = model(inputs, mode='compute_vae') + >>> outputs = model(inputs, mode='compute_eval') .. note:: For specific examples, one can refer to API doc of ``compute_actor`` and ``compute_critic`` respectively. @@ -102,6 +118,21 @@ def forward(self, inputs: Dict[str, torch.Tensor], mode: str) -> Dict[str, torch return getattr(self, mode)(inputs) def compute_critic(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Overview: + Use critic network to compute q value. + Arguments: + - inputs (:obj:`Dict`): Input dict data, including obs and action tensor. + Returns: + - outputs (:obj:`Dict`): Dict containing keywords ``q_value`` (:obj:`torch.Tensor`). + Shapes: + - inputs (:obj:`Dict`): :math:`(B, N, D)`, where B is batch size, N is sample number, D is input dimension. + - outputs (:obj:`Dict`): :math:`(B, N)`. + Examples: + >>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)} + >>> model = BCQ(32, 6) + >>> outputs = model.compute_critic(inputs) + """ obs, action = inputs['obs'], inputs['action'] if len(action.shape) == 1: # (B, ) -> (B, 1) action = action.unsqueeze(1) @@ -110,6 +141,21 @@ def compute_critic(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Ten return {'q_value': x} def compute_actor(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]: + """ + Overview: + Use actor network to compute action. + Arguments: + - inputs (:obj:`Dict`): Input dict data, including obs and action tensor. + Returns: + - outputs (:obj:`Dict`): Dict containing keywords ``action`` (:obj:`torch.Tensor`). + Shapes: + - inputs (:obj:`Dict`): :math:`(B, N, D)`, where B is batch size, N is sample number, D is input dimension. + - outputs (:obj:`Dict`): :math:`(B, N)`. 
+ Examples: + >>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)} + >>> model = BCQ(32, 6) + >>> outputs = model.compute_actor(inputs) + """ input = torch.cat([inputs['obs'], inputs['action']], -1) x = self.actor(input) action = self.phi * 1 * torch.tanh(x) @@ -117,9 +163,41 @@ def compute_actor(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, Union[torc return {'action': action} def compute_vae(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Overview: + Use vae network to compute action. + Arguments: + - inputs (:obj:`Dict`): Input dict data, including obs and action tensor. + Returns: + - outputs (:obj:`Dict`): Dict containing keywords ``recons_action`` (:obj:`torch.Tensor`), \ + ``prediction_residual`` (:obj:`torch.Tensor`), ``input`` (:obj:`torch.Tensor`), \ + ``mu`` (:obj:`torch.Tensor`), ``log_var`` (:obj:`torch.Tensor`) and ``z`` (:obj:`torch.Tensor`). + Shapes: + - inputs (:obj:`Dict`): :math:`(B, N, D)`, where B is batch size, N is sample number, D is input dimension. + - outputs (:obj:`Dict`): :math:`(B, N)`. + Examples: + >>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)} + >>> model = BCQ(32, 6) + >>> outputs = model.compute_vae(inputs) + """ return self.vae.forward(inputs) def compute_eval(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Overview: + Use actor network to compute action. + Arguments: + - inputs (:obj:`Dict`): Input dict data, including obs and action tensor. + Returns: + - outputs (:obj:`Dict`): Dict containing keywords ``action`` (:obj:`torch.Tensor`). + Shapes: + - inputs (:obj:`Dict`): :math:`(B, N, D)`, where B is batch size, N is sample number, D is input dimension. + - outputs (:obj:`Dict`): :math:`(B, N)`. + Examples: + >>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)} + >>> model = BCQ(32, 6) + >>> outputs = model.compute_eval(inputs) + """ obs = inputs['obs'] obs_rep = obs.clone().unsqueeze(0).repeat_interleave(100, dim=0) z = torch.randn((obs_rep.shape[0], obs_rep.shape[1], self.action_shape * 2)).to(obs.device).clamp(-0.5, 0.5) diff --git a/ding/model/template/collaq.py b/ding/model/template/collaq.py index dd8db59986..9872d0684a 100644 --- a/ding/model/template/collaq.py +++ b/ding/model/template/collaq.py @@ -14,7 +14,7 @@ class CollaQMultiHeadAttention(nn.Module): Overview: The head of collaq attention module. 
Interface: - __init__, forward + ``__init__``, ``forward`` """ def __init__( @@ -69,8 +69,26 @@ def forward(self, q, k, v, mask=None): - q (:obj:`torch.nn.Sequential`): the transformer information q - k (:obj:`torch.nn.Sequential`): the transformer information k - v (:obj:`torch.nn.Sequential`): the transformer information v - Output: + Returns: - q (:obj:`torch.nn.Sequential`): the transformer output q + - residual (:obj:`torch.nn.Sequential`): the transformer output residual + Shapes: + - q (:obj:`torch.nn.Sequential`): :math:`(B, L, N)` where B is batch_size, L is sequence length, \ + N is the size of input q + - k (:obj:`torch.nn.Sequential`): :math:`(B, L, N)` where B is batch_size, L is sequence length, \ + N is the size of input k + - v (:obj:`torch.nn.Sequential`): :math:`(B, L, N)` where B is batch_size, L is sequence length, \ + N is the size of input v + - q (:obj:`torch.nn.Sequential`): :math:`(B, L, N)` where B is batch_size, L is sequence length, \ + N is the size of output q + - residual (:obj:`torch.nn.Sequential`): :math:`(B, L, N)` where B is batch_size, L is sequence length, \ + N is the size of output residual + Examples: + >>> net = CollaQMultiHeadAttention(1, 2, 3, 4, 5, 6) + >>> q = torch.randn(1, 2, 2) + >>> k = torch.randn(1, 3, 3) + >>> v = torch.randn(1, 3, 3) + >>> q, residual = net(q, k, v) """ d_k, d_v, n_head = self.d_k, self.d_v, self.n_head batch_size, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1) @@ -104,7 +122,7 @@ class CollaQSMACAttentionModule(nn.Module): Collaq attention module. Used to get agent's attention observation. It includes agent's observation\ and agent's part of the observation information of the agent's concerned allies Interface: - __init__, _cut_obs, forward + ``__init__``, ``_cut_obs``, ``forward`` """ def __init__( @@ -140,9 +158,16 @@ def _cut_obs(self, obs: torch.Tensor): cut the observed information into self's observation and allay's observation Arguments: - obs (:obj:`torch.Tensor`): input each agent's observation - Return: + Returns: - self_features (:obj:`torch.Tensor`): output self agent's attention observation - ally_features (:obj:`torch.Tensor`): output ally agent's attention observation + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(T, B, A, N)` where T is timestep, B is batch_size, \ + A is agent_num, N is obs_shape + - self_features (:obj:`torch.Tensor`): :math:`(T, B, A, N)` where T is timestep, B is batch_size, \ + A is agent_num, N is self_feature_range[1] - self_feature_range[0] + - ally_features (:obj:`torch.Tensor`): :math:`(T, B, A, N)` where T is timestep, B is batch_size, \ + A is agent_num, N is ally_feature_range[1] - ally_feature_range[0] """ # obs shape = (T, B, A, obs_shape) self_features = obs[:, :, :, self.self_feature_range[0]:self.self_feature_range[1]] @@ -155,8 +180,11 @@ def forward(self, inputs: torch.Tensor): forward computation to get agent's attention observation information Arguments: - obs (:obj:`torch.Tensor`): input each agent's observation - Return: + Returns: - obs (:obj:`torch.Tensor`): output agent's attention observation + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(T, B, A, N)` where T is timestep, B is batch_size, \ + A is agent_num, N is obs_shape """ # obs shape = (T, B ,A, obs_shape) obs = inputs @@ -182,9 +210,16 @@ def forward(self, inputs: torch.Tensor): class CollaQ(nn.Module): """ Overview: - CollaQ network + The network of CollaQ (Collaborative Q-learning) algorithm. + It includes two parts: q_network and q_alone_network. 
+ The q_network is used to get the q_value of the agent's observation and \ + the agent's part of the observation information of the agent's concerned allies. + The q_alone_network is used to get the q_value of the agent's observation and \ + the agent's observation information without the agent's concerned allies. + Multi-Agent Collaboration via Reward Attribution Decomposition + https://arxiv.org/abs/2010.08531 Interface: - __init__, forward, _setup_global_encoder + ``__init__``, ``forward``, ``_setup_global_encoder`` """ def __init__( @@ -275,7 +310,8 @@ def __init__( def forward(self, data: dict, single_step: bool = True) -> dict: """ Overview: - forward computation graph of collaQ network + The forward method calculates the q_value of each agent and the total q_value of all agents. + The q_value of each agent is calculated by the q_network, and the total q_value is calculated by the mixer. Arguments: - data (:obj:`dict`): input data dict with keys ['obs', 'prev_state', 'action'] - agent_state (:obj:`torch.Tensor`): each agent local state(obs) @@ -302,6 +338,32 @@ def forward(self, data: dict, single_step: bool = True) -> dict: - total_q (:obj:`torch.Tensor`): :math:`(T, B)` - agent_q (:obj:`torch.Tensor`): :math:`(T, B, A, P)`, where P is action_shape - next_state (:obj:`list`): math:`(B, A)`, a list of length B, and each element is a list of length A + Examples: + >>> collaQ_model = CollaQ( + >>> agent_num=4, + >>> obs_shape=32, + >>> alone_obs_shape=24, + >>> global_obs_shape=32 * 4, + >>> action_shape=9, + >>> hidden_size_list=[128, 64], + >>> self_feature_range=[8, 10], + >>> ally_feature_range=[10, 16], + >>> attention_size=64, + >>> mixer=True, + >>> activation=torch.nn.Tanh() + >>> ) + >>> data={ + >>> 'obs': { + >>> 'agent_state': torch.randn(8, 4, 4, 32), + >>> 'agent_alone_state': torch.randn(8, 4, 4, 24), + >>> 'agent_alone_padding_state': torch.randn(8, 4, 4, 32), + >>> 'global_state': torch.randn(8, 4, 32 * 4), + >>> 'action_mask': torch.randint(0, 2, size=(8, 4, 4, 9)) + >>> }, + >>> 'prev_state': [[[None for _ in range(4)] for _ in range(3)] for _ in range(4)], + >>> 'action': torch.randint(0, 9, size=(8, 4, 4)) + >>> } + >>> output = collaQ_model(data, single_step=False) """ agent_state, agent_alone_state = data['obs']['agent_state'], data['obs']['agent_alone_state'] agent_alone_padding_state = data['obs']['agent_alone_padding_state'] @@ -426,7 +488,7 @@ def _setup_global_encoder(self, global_obs_shape: int, embedding_size: int) -> t Arguments: - global_obs_shape (:obj:`int`): the dimension of global observation state - embedding_size (:obj:`int`): the dimension of state emdedding - Return: + Returns: - outputs (:obj:`torch.nn.Module`): Global observation encoding network """ return MLP(global_obs_shape, embedding_size, embedding_size, 2, activation=self._act) diff --git a/ding/model/template/coma.py b/ding/model/template/coma.py index c120dfd78d..02eb286e84 100644 --- a/ding/model/template/coma.py +++ b/ding/model/template/coma.py @@ -11,9 +11,9 @@ class COMAActorNetwork(nn.Module): """ Overview: - Decentralized actor network in COMA + Decentralized actor network in COMA algorithm. 
Interface: - __init__, forward + ``__init__``, ``forward`` """ def __init__( @@ -24,7 +24,7 @@ def __init__( ): """ Overview: - initialize COMA actor network + Initialize COMA actor network Arguments: - obs_shape (:obj:`int`): the dimension of each agent's observation state - action_shape (:obj:`int`): the dimension of action shape @@ -35,10 +35,30 @@ def __init__( def forward(self, inputs: Dict) -> Dict: """ + Overview: + The forward computation graph of COMA actor network + Arguments: + - inputs (:obj:`dict`): input data dict with keys ['obs', 'prev_state'] + - agent_state (:obj:`torch.Tensor`): each agent local state(obs) + - action_mask (:obj:`torch.Tensor`): the masked action + - prev_state (:obj:`torch.Tensor`): the previous hidden state + Returns: + - output (:obj:`dict`): output data dict with keys ['logit', 'next_state', 'action_mask'] ArgumentsKeys: - necessary: ``obs`` { ``agent_state``, ``action_mask`` }, ``prev_state`` ReturnsKeys: - necessary: ``logit``, ``next_state``, ``action_mask`` + Examples: + >>> T, B, A, N = 4, 8, 3, 32 + >>> embedding_dim = 64 + >>> action_dim = 6 + >>> data = torch.randn(T, B, A, N) + >>> model = COMAActorNetwork((N, ), action_dim, [128, embedding_dim]) + >>> prev_state = [[None for _ in range(A)] for _ in range(B)] + >>> for t in range(T): + >>> inputs = {'obs': {'agent_state': data[t], 'action_mask': None}, 'prev_state': prev_state} + >>> outputs = model(inputs) + >>> logit, prev_state = outputs['logit'], outputs['next_state'] """ agent_state = inputs['obs']['agent_state'] prev_state = inputs['prev_state'] @@ -62,9 +82,9 @@ def forward(self, inputs: Dict) -> Dict: class COMACriticNetwork(nn.Module): """ Overview: - Centralized critic network in COMA + Centralized critic network in COMA algorithm. Interface: - __init__, forward + ``__init__``, ``forward`` """ def __init__( @@ -80,6 +100,14 @@ def __init__( - input_size (:obj:`int`): the size of input global observation - action_shape (:obj:`int`): the dimension of action shape - hidden_size_list (:obj:`list`): the list of hidden size, default to 128 + Returns: + - output (:obj:`dict`): output data dict with keys ['q_value'] + Shapes: + - obs (:obj:`dict`): ``agent_state``: :math:`(T, B, A, N, D)`, ``action_mask``: :math:`(T, B, A, N, A)` + - prev_state (:obj:`list`): :math:`[[[h, c] for _ in range(A)] for _ in range(B)]` + - logit (:obj:`torch.Tensor`): :math:`(T, B, A, N, A)` + - next_state (:obj:`list`): :math:`[[[h, c] for _ in range(A)] for _ in range(B)]` + - action_mask (:obj:`torch.Tensor`): :math:`(T, B, A, N, A)` """ super(COMACriticNetwork, self).__init__() self.action_shape = action_shape @@ -101,6 +129,19 @@ def forward(self, data: Dict) -> Dict: - necessary: ``obs`` { ``agent_state``, ``global_state`` }, ``action``, ``prev_state`` ReturnsKeys: - necessary: ``q_value`` + Examples: + >>> agent_num, bs, T = 4, 3, 8 + >>> obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9 + >>> coma_model = COMACriticNetwork( + >>> obs_dim - action_dim + global_obs_dim + 2 * action_dim * agent_num, action_dim) + >>> data = { + >>> 'obs': { + >>> 'agent_state': torch.randn(T, bs, agent_num, obs_dim), + >>> 'global_state': torch.randn(T, bs, global_obs_dim), + >>> }, + >>> 'action': torch.randint(0, action_dim, size=(T, bs, agent_num)), + >>> } + >>> output = coma_model(data) """ x = self._preprocess_data(data) q = self.mlp(x) @@ -145,8 +186,13 @@ def _preprocess_data(self, data: Dict) -> torch.Tensor: class COMA(nn.Module): """ Overview: - COMA network is QAC-type actor-critic. 
+ The network of COMA algorithm, which is QAC-type actor-critic. + Interface: + ``__init__``, ``forward`` + Properties: + - mode (:obj:`list`): The list of forward mode, including ``compute_actor`` and ``compute_critic`` """ + mode = ['compute_actor', 'compute_critic'] def __init__( @@ -174,12 +220,53 @@ def __init__( def forward(self, inputs: Dict, mode: str) -> Dict: """ + Overview: + forward computation graph of COMA network + Arguments: + - inputs (:obj:`dict`): input data dict with keys ['obs', 'prev_state', 'action'] + - agent_state (:obj:`torch.Tensor`): each agent local state(obs) + - global_state (:obj:`torch.Tensor`): global state(obs) + - action (:obj:`torch.Tensor`): the masked action ArgumentsKeys: - necessary: ``obs`` { ``agent_state``, ``global_state``, ``action_mask`` }, ``action``, ``prev_state`` ReturnsKeys: - necessary: - compute_critic: ``q_value`` - compute_actor: ``logit``, ``next_state``, ``action_mask`` + Shapes: + - obs (:obj:`dict`): ``agent_state``: :math:`(T, B, A, N, D)`, ``action_mask``: :math:`(T, B, A, N, A)` + - prev_state (:obj:`list`): :math:`[[[h, c] for _ in range(A)] for _ in range(B)]` + - logit (:obj:`torch.Tensor`): :math:`(T, B, A, N, A)` + - next_state (:obj:`list`): :math:`[[[h, c] for _ in range(A)] for _ in range(B)]` + - action_mask (:obj:`torch.Tensor`): :math:`(T, B, A, N, A)` + - q_value (:obj:`torch.Tensor`): :math:`(T, B, A, N, A)` + Examples: + >>> agent_num, bs, T = 4, 3, 8 + >>> agent_num, bs, T = 4, 3, 8 + >>> obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9 + >>> coma_model = COMA( + >>> agent_num=agent_num, + >>> obs_shape=dict(agent_state=(obs_dim, ), global_state=(global_obs_dim, )), + >>> action_shape=action_dim, + >>> actor_hidden_size_list=[128, 64], + >>> ) + >>> prev_state = [[None for _ in range(agent_num)] for _ in range(bs)] + >>> data = { + >>> 'obs': { + >>> 'agent_state': torch.randn(T, bs, agent_num, obs_dim), + >>> 'action_mask': None, + >>> }, + >>> 'prev_state': prev_state, + >>> } + >>> output = coma_model(data, mode='compute_actor') + >>> data= { + >>> 'obs': { + >>> 'agent_state': torch.randn(T, bs, agent_num, obs_dim), + >>> 'global_state': torch.randn(T, bs, global_obs_dim), + >>> }, + >>> 'action': torch.randint(0, action_dim, size=(T, bs, agent_num)), + >>> } + >>> output = coma_model(data, mode='compute_critic') """ assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode) if mode == 'compute_actor': diff --git a/ding/model/template/decision_transformer.py b/ding/model/template/decision_transformer.py index 6aca4041da..3d35497383 100644 --- a/ding/model/template/decision_transformer.py +++ b/ding/model/template/decision_transformer.py @@ -71,7 +71,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: - x (:obj:`torch.Tensor`): The input tensor. Returns: - out (:obj:`torch.Tensor`): Output tensor, the shape is the same as the input. - Examples: >>> inputs = torch.randn(2, 4, 64) >>> model = MaskedCausalAttention(64, 5, 4, 0.1) @@ -142,7 +141,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: - x (:obj:`torch.Tensor`): The input tensor. Returns: - output (:obj:`torch.Tensor`): Output tensor, the shape is the same as the input. - Examples: >>> inputs = torch.randn(2, 4, 64) >>> model = Block(64, 5, 4, 0.1) @@ -260,7 +258,6 @@ def forward( Returns: - output (:obj:`Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`): Output contains three tensors, \ they are correspondingly the predicted states, predicted actions and predicted return-to-go. 
- Examples: >>> B, T = 4, 6 >>> state_dim = 3 @@ -274,21 +271,16 @@ def forward( n_heads=2,\ drop_p=0.1,\ ) - >>> timesteps = torch.randint(0, 100, [B, 3 * T - 1, 1], dtype=torch.long) # B x T >>> states = torch.randn([B, T, state_dim]) # B x T x state_dim - >>> actions = torch.randint(0, act_dim, [B, T, 1]) >>> action_target = torch.randint(0, act_dim, [B, T, 1]) >>> returns_to_go_sample = torch.tensor([1, 0.8, 0.6, 0.4, 0.2, 0.]).repeat([B, 1]).unsqueeze(-1).float() - >>> traj_mask = torch.ones([B, T], dtype=torch.long) # B x T >>> actions = actions.squeeze(-1) - >>> state_preds, action_preds, return_preds = DT_model.forward(\ timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go\ ) - >>> assert state_preds.shape == torch.Size([B, T, state_dim]) >>> assert return_preds.shape == torch.Size([B, T, 1]) >>> assert action_preds.shape == torch.Size([B, T, act_dim]) diff --git a/ding/model/template/diffusion.py b/ding/model/template/diffusion.py index f8b48f3061..f56b84d1b6 100755 --- a/ding/model/template/diffusion.py +++ b/ding/model/template/diffusion.py @@ -9,6 +9,194 @@ from ding.torch_utils.network.diffusion import extract, cosine_beta_schedule, apply_conditioning, \ DiffusionUNet1d, TemporalValue +import torch +import torch.nn as nn +import einops +from einops.layers.torch import Rearrange +import pdb + +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import einops +from einops.layers.torch import Rearrange +import pdb + + +#-----------------------------------------------------------------------------# +#---------------------------------- modules ----------------------------------# +#-----------------------------------------------------------------------------# + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + +class Downsample1d(nn.Module): + def __init__(self, dim): + super().__init__() + self.conv = nn.Conv1d(dim, dim, 3, 2, 1) + + def forward(self, x): + return self.conv(x) + +class Upsample1d(nn.Module): + def __init__(self, dim): + super().__init__() + self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) + + def forward(self, x): + return self.conv(x) + +class Conv1dBlock(nn.Module): + ''' + Conv1d --> GroupNorm --> Mish + ''' + + def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): + super().__init__() + + self.block = nn.Sequential( + nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), + Rearrange('batch channels horizon -> batch channels 1 horizon'), + nn.GroupNorm(n_groups, out_channels), + Rearrange('batch channels 1 horizon -> batch channels horizon'), + nn.Mish(), + ) + + def forward(self, x): + return self.block(x) + +#-----------------------------------------------------------------------------# +#--------------------------------- attention ---------------------------------# +#-----------------------------------------------------------------------------# + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x, *args, **kwargs): + return self.fn(x, *args, **kwargs) + x + +class LayerNorm(nn.Module): + def __init__(self, dim, eps = 1e-5): + super().__init__() + self.eps 
= eps + self.g = nn.Parameter(torch.ones(1, dim, 1)) + self.b = nn.Parameter(torch.zeros(1, dim, 1)) + + def forward(self, x): + var = torch.var(x, dim=1, unbiased=False, keepdim=True) + mean = torch.mean(x, dim=1, keepdim=True) + return (x - mean) / (var + self.eps).sqrt() * self.g + self.b + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.fn = fn + self.norm = LayerNorm(dim) + + def forward(self, x): + x = self.norm(x) + return self.fn(x) + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.scale = dim_head ** -0.5 + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv1d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv1d(hidden_dim, dim, 1) + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = 1) + q, k, v = map(lambda t: einops.rearrange(t, 'b (h c) d -> b h c d', h=self.heads), qkv) + q = q * self.scale + + k = k.softmax(dim = -1) + context = torch.einsum('b h d n, b h e n -> b h d e', k, v) + + out = torch.einsum('b h d e, b h d n -> b h e n', context, q) + out = einops.rearrange(out, 'b h c d -> b (h c) d') + return self.to_out(out) + +#-----------------------------------------------------------------------------# +#---------------------------------- sampling ---------------------------------# +#-----------------------------------------------------------------------------# + +def extract(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + +def cosine_beta_schedule(timesteps, s=0.008, dtype=torch.float32): + """ + cosine schedule + as proposed in https://openreview.net/forum?id=-NEXDKk8gZ + """ + steps = timesteps + 1 + x = np.linspace(0, steps, steps) + alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + betas_clipped = np.clip(betas, a_min=0, a_max=0.999) + return torch.tensor(betas_clipped, dtype=dtype) + +def apply_conditioning(x, conditions, action_dim): + for t, val in conditions.items(): + x[:, t, action_dim:] = val.clone() + return x + + +#-----------------------------------------------------------------------------# +#---------------------------------- losses -----------------------------------# +#-----------------------------------------------------------------------------# + + + +class ResidualTemporalBlock(nn.Module): + + def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5): + super().__init__() + + self.blocks = nn.ModuleList([ + Conv1dBlock(inp_channels, out_channels, kernel_size), + Conv1dBlock(out_channels, out_channels, kernel_size), + ]) + + self.time_mlp = nn.Sequential( + nn.Mish(), + nn.Linear(embed_dim, out_channels), + Rearrange('batch t -> batch t 1'), + ) + + self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1) \ + if inp_channels != out_channels else nn.Identity() + + def forward(self, x, t): + ''' + x : [ batch_size x inp_channels x horizon ] + t : [ batch_size x embed_dim ] + returns: + out : [ batch_size x out_channels x horizon ] + ''' + out = self.blocks[0](x) + self.time_mlp(t) + out = self.blocks[1](out) + return out + self.residual_conv(x) + Sample = namedtuple('Sample', 'trajectories values chains') @@ -293,6 +481,106 @@ def forward(self, cond, *args, **kwargs): return self.conditional_sample(cond, *args, **kwargs) +class TemporalUnet(nn.Module): + + def __init__( + self, + horizon, 
+ transition_dim, + cond_dim, + dim=32, + dim_mults=(1, 2, 4, 8), + attention=False, + ): + super().__init__() + + dims = [transition_dim, *map(lambda m: dim * m, dim_mults)] + in_out = list(zip(dims[:-1], dims[1:])) + print(f'[ models/temporal ] Channel dimensions: {in_out}') + + time_dim = dim + self.time_mlp = nn.Sequential( + SinusoidalPosEmb(dim), + nn.Linear(dim, dim * 4), + nn.Mish(), + nn.Linear(dim * 4, dim), + ) + + self.downs = nn.ModuleList([]) + self.ups = nn.ModuleList([]) + num_resolutions = len(in_out) + + print(in_out) + for ind, (dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (num_resolutions - 1) + + self.downs.append(nn.ModuleList([ + ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=horizon), + ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=horizon), + Residual(PreNorm(dim_out, LinearAttention(dim_out))) if attention else nn.Identity(), + Downsample1d(dim_out) if not is_last else nn.Identity() + ])) + + if not is_last: + horizon = horizon // 2 + + mid_dim = dims[-1] + self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon) + self.mid_attn = Residual(PreNorm(mid_dim, LinearAttention(mid_dim))) if attention else nn.Identity() + self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon) + + for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): + is_last = ind >= (num_resolutions - 1) + + self.ups.append(nn.ModuleList([ + ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=horizon), + ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=horizon), + Residual(PreNorm(dim_in, LinearAttention(dim_in))) if attention else nn.Identity(), + Upsample1d(dim_in) if not is_last else nn.Identity() + ])) + + if not is_last: + horizon = horizon * 2 + + self.final_conv = nn.Sequential( + Conv1dBlock(dim, dim, kernel_size=5), + nn.Conv1d(dim, transition_dim, 1), + ) + + def forward(self, x, cond, time): + ''' + x : [ batch x horizon x transition ] + ''' + + x = einops.rearrange(x, 'b h t -> b t h') + + t = self.time_mlp(time) + h = [] + + for resnet, resnet2, attn, downsample in self.downs: + x = resnet(x, t) + x = resnet2(x, t) + x = attn(x) + h.append(x) + x = downsample(x) + + x = self.mid_block1(x, t) + x = self.mid_attn(x) + x = self.mid_block2(x, t) + + for resnet, resnet2, attn, upsample in self.ups: + x = torch.cat((x, h.pop()), dim=1) + x = resnet(x, t) + x = resnet2(x, t) + x = attn(x) + x = upsample(x) + + x = self.final_conv(x) + + x = einops.rearrange(x, 'b t h -> b h t') + return x + + class ValueDiffusion(GaussianDiffusion): """ Overview: diff --git a/ding/model/template/ebm.py b/ding/model/template/ebm.py index fe1c073b17..4b91fd1b6d 100644 --- a/ding/model/template/ebm.py +++ b/ding/model/template/ebm.py @@ -15,10 +15,17 @@ from ding.utils import MODEL_REGISTRY, STOCHASTIC_OPTIMIZER_REGISTRY from ding.torch_utils import unsqueeze_repeat from ding.model.wrapper import IModelWrapper -from ..common import RegressionHead +from ding.model.common import RegressionHead def create_stochastic_optimizer(device: str, stochastic_optimizer_config: dict): + """ + Overview: + Create stochastic optimizer. + Arguments: + - device (:obj:`str`): Device. + - stochastic_optimizer_config (:obj:`dict`): Stochastic optimizer config. 
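# Editor's note: an illustrative sketch, not part of this patch. It shows how the helper documented
# above is typically called: the ``type`` key selects one of the registered optimizers ('dfo',
# 'ardfo', 'mcmc') and is popped from the config, and the remaining keys are forwarded as keyword
# arguments. The concrete values below are hypothetical and assume this module's names are in scope.
cfg = {'type': 'dfo', 'train_samples': 64, 'inference_samples': 256}
opt = create_stochastic_optimizer('cpu', cfg)  # roughly DFO(train_samples=64, inference_samples=256, device='cpu')
opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))  # bounds must be set before sampling/inference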
+ """ return STOCHASTIC_OPTIMIZER_REGISTRY.build( stochastic_optimizer_config.pop("type"), device=device, **stochastic_optimizer_config ) @@ -45,20 +52,34 @@ def wrapper(*args, **kwargs): class StochasticOptimizer(ABC): + """ + Overview: + Base class for stochastic optimizers. + Interface: + ``__init__``, ``_sample``, ``_get_best_action_sample``, ``set_action_bounds``, ``sample``, ``infer`` + """ def _sample(self, obs: torch.Tensor, num_samples: int) -> Tuple[torch.Tensor, torch.Tensor]: """ Overview: - Helper method for drawing action samples from the uniform random distribution \ + Drawing action samples from the uniform random distribution \ and tiling observations to the same shape as action samples. - Arguments: - - obs (:obj:`torch.Tensor`): Observation of shape (B, O). - - num_samples (:obj:`int`): The number of negative samples (N). - + - obs (:obj:`torch.Tensor`): Observation. + - num_samples (:obj:`int`): The number of negative samples. Returns: - - tiled_obs (:obj:`torch.Tensor`): Observation of shape (B, N, O). - - action (:obj:`torch.Tensor`): Action of shape (B, N, A). + - tiled_obs (:obj:`torch.Tensor`): Observations tiled. + - action (:obj:`torch.Tensor`): Action sampled. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, O)`. + - num_samples (:obj:`int`): :math:`N`. + - tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`. + - action (:obj:`torch.Tensor`): :math:`(B, N, A)`. + Examples: + >>> obs = torch.randn(2, 4) + >>> opt = StochasticOptimizer() + >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0)) + >>> tiled_obs, action = opt._sample(obs, 8) """ size = (obs.shape[0], num_samples, self.action_bounds.shape[1]) low, high = self.action_bounds[0, :], self.action_bounds[1, :] @@ -72,13 +93,22 @@ def _get_best_action_sample(obs: torch.Tensor, action_samples: torch.Tensor, ebm """ Overview: Return one action for each batch with highest probability (lowest energy). - Arguments: - - obs (:obj:`torch.Tensor`): Observation of shape (B, N, O). - - action_samples (:obj:`torch.Tensor`): Action of shape (B, N, A). - + - obs (:obj:`torch.Tensor`): Observation. + - action_samples (:obj:`torch.Tensor`): Action from uniform distributions. Returns: - - best_action_samples (:obj:`torch.Tensor`): Action of shape (B, A). + - best_action_samples (:obj:`torch.Tensor`): Best action. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, O)`. + - action_samples (:obj:`torch.Tensor`): :math:`(B, N, A)`. + - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`. + Examples: + >>> obs = torch.randn(2, 4) + >>> action_samples = torch.randn(2, 8, 5) + >>> ebm = EBM(4, 5) + >>> opt = StochasticOptimizer() + >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0)) + >>> best_action_samples = opt._get_best_action_sample(obs, action_samples, ebm) """ # (B, N) energies = ebm.forward(obs, action_samples) @@ -91,10 +121,17 @@ def set_action_bounds(self, action_bounds: np.ndarray): """ Overview: Set action bounds calculated from the dataset statistics. - Arguments: - action_bounds (:obj:`np.ndarray`): Array of shape (2, A), \ where action_bounds[0] is lower bound and action_bounds[1] is upper bound. + Returns: + - action_bounds (:obj:`torch.Tensor`): Action bounds. + Shapes: + - action_bounds (:obj:`np.ndarray`): :math:`(2, A)`. + - action_bounds (:obj:`torch.Tensor`): :math:`(2, A)`. 
+ Examples: + >>> opt = StochasticOptimizer() + >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0)) """ self.action_bounds = torch.as_tensor(action_bounds, dtype=torch.float32).to(self.device) @@ -103,14 +140,17 @@ def sample(self, obs: torch.Tensor, ebm: nn.Module) -> Tuple[torch.Tensor, torch """ Overview: Create tiled observations and sample counter-negatives for InfoNCE loss. - Arguments: - - obs (:obj:`torch.Tensor`): Observation of shape (B, O). + - obs (:obj:`torch.Tensor`): Observations. - ebm (:obj:`torch.nn.Module`): Energy based model. - Returns: - - tiled_obs (:obj:`torch.Tensor`): Observation of shape (B, N, O). - - action (:obj:`torch.Tensor`): Action of shape (B, N, A). + - tiled_obs (:obj:`torch.Tensor`): Tiled observations. + - action (:obj:`torch.Tensor`): Actions. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, O)`. + - ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`. + - tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`. + - action (:obj:`torch.Tensor`): :math:`(B, N, A)`. .. note:: In the case of derivative-free optimization, this function will simply call _sample. """ @@ -122,16 +162,27 @@ def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor: Overview: Optimize for the best action conditioned on the current observation. Arguments: - - obs (:obj:`torch.Tensor`): Observation of shape (B, O). + - obs (:obj:`torch.Tensor`): Observations. - ebm (:obj:`torch.nn.Module`): Energy based model. Returns: - - best_action_samples (:obj:`torch.Tensor`): Action of shape (B, A). + - best_action_samples (:obj:`torch.Tensor`): Best actions. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, O)`. + - ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`. + - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`. """ raise NotImplementedError @STOCHASTIC_OPTIMIZER_REGISTRY.register('dfo') class DFO(StochasticOptimizer): + """ + Overview: + Derivative-Free Optimizer in paper Implicit Behavioral Cloning. + https://arxiv.org/abs/2109.00137 + Interface: + ``init``, ``sample``, ``infer`` + """ def __init__( self, @@ -142,6 +193,17 @@ def __init__( inference_samples: int = 16384, device: str = 'cpu', ): + """ + Overview: + Initialize the Derivative-Free Optimizer + Arguments: + - noise_scale (:obj:`float`): Initial noise scale. + - noise_shrink (:obj:`float`): Noise scale shrink rate. + - iters (:obj:`int`): Number of iterations. + - train_samples (:obj:`int`): Number of samples for training. + - inference_samples (:obj:`int`): Number of samples for inference. + - device (:obj:`str`): Device. + """ self.action_bounds = None self.noise_scale = noise_scale self.noise_shrink = noise_shrink @@ -151,10 +213,51 @@ def __init__( self.device = device def sample(self, obs: torch.Tensor, ebm: nn.Module) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Drawing action samples from the uniform random distribution \ + and tiling observations to the same shape as action samples. + Arguments: + - obs (:obj:`torch.Tensor`): Observations. + - ebm (:obj:`torch.nn.Module`): Energy based model. + Returns: + - tiled_obs (:obj:`torch.Tensor`): Tiled observation. + - action_samples (:obj:`torch.Tensor`): Action samples. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, O)`. + - ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`. + - tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`. + - action_samples (:obj:`torch.Tensor`): :math:`(B, N, A)`. 
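# Editor's note: an illustrative sketch, not part of this patch. It is a conceptual, self-contained
# version of the derivative-free inference loop that the DFO docstrings above describe (sample
# uniformly, score by energy, resample the low-energy candidates, add shrinking noise); the actual
# implementation is ``DFO.infer`` below, which may differ in details.
import torch

def dfo_infer_sketch(obs, energy_fn, bounds, iters=3, num_samples=256, noise_scale=0.33, noise_shrink=0.5):
    # obs: (B, O); bounds: (2, A) tensor of per-dimension lower/upper action bounds.
    low, high = bounds[0], bounds[1]
    B, A = obs.shape[0], bounds.shape[1]
    action = low + (high - low) * torch.rand(B, num_samples, A)         # (B, N, A) uniform proposals
    tiled_obs = obs.unsqueeze(1).expand(B, num_samples, obs.shape[1])   # (B, N, O)
    for _ in range(iters):
        energies = energy_fn(tiled_obs, action)                          # (B, N), lower energy is better
        probs = torch.softmax(-energies, dim=-1)
        idx = torch.multinomial(probs, num_samples, replacement=True)    # resample low-energy candidates
        action = torch.gather(action, 1, idx.unsqueeze(-1).expand(-1, -1, A))
        action = (action + noise_scale * torch.randn_like(action)).clamp(low, high)
        noise_scale *= noise_shrink
    energies = energy_fn(tiled_obs, action)
    best = energies.argmin(dim=-1)                                        # (B,)
    return action[torch.arange(B), best]                                  # (B, A)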
+ Examples: + >>> obs = torch.randn(2, 4) + >>> ebm = EBM(4, 5) + >>> opt = DFO() + >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0)) + >>> tiled_obs, action_samples = opt.sample(obs, ebm) + """ return self._sample(obs, self.train_samples) @torch.no_grad() def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor: + """ + Overview: + Optimize for the best action conditioned on the current observation. + Arguments: + - obs (:obj:`torch.Tensor`): Observations. + - ebm (:obj:`torch.nn.Module`): Energy based model. + Returns: + - best_action_samples (:obj:`torch.Tensor`): Actions. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, O)`. + - ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`. + - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`. + Examples: + >>> obs = torch.randn(2, 4) + >>> ebm = EBM(4, 5) + >>> opt = DFO() + >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0)) + >>> best_action_samples = opt.infer(obs, ebm) + """ noise_scale = self.noise_scale # (B, N, O), (B, N, A) @@ -181,6 +284,13 @@ def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor: @STOCHASTIC_OPTIMIZER_REGISTRY.register('ardfo') class AutoRegressiveDFO(DFO): + """ + Overview: + AutoRegressive Derivative-Free Optimizer in paper Implicit Behavioral Cloning. + https://arxiv.org/abs/2109.00137 + Interface: + ``__init__``, ``infer`` + """ def __init__( self, @@ -191,10 +301,40 @@ def __init__( inference_samples: int = 4096, device: str = 'cpu', ): + """ + Overview: + Initialize the AutoRegressive Derivative-Free Optimizer + Arguments: + - noise_scale (:obj:`float`): Initial noise scale. + - noise_shrink (:obj:`float`): Noise scale shrink rate. + - iters (:obj:`int`): Number of iterations. + - train_samples (:obj:`int`): Number of samples for training. + - inference_samples (:obj:`int`): Number of samples for inference. + - device (:obj:`str`): Device. + """ super().__init__(noise_scale, noise_shrink, iters, train_samples, inference_samples, device) @torch.no_grad() def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor: + """ + Overview: + Optimize for the best action conditioned on the current observation. + Arguments: + - obs (:obj:`torch.Tensor`): Observations. + - ebm (:obj:`torch.nn.Module`): Energy based model. + Returns: + - best_action_samples (:obj:`torch.Tensor`): Actions. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, O)`. + - ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`. + - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`. + Examples: + >>> obs = torch.randn(2, 4) + >>> ebm = EBM(4, 5) + >>> opt = AutoRegressiveDFO() + >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0)) + >>> best_action_samples = opt.infer(obs, ebm) + """ noise_scale = self.noise_scale # (B, N, O), (B, N, A) @@ -230,38 +370,91 @@ def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor: @STOCHASTIC_OPTIMIZER_REGISTRY.register('mcmc') class MCMC(StochasticOptimizer): + """ + Overview: + MCMC method as stochastic optimizers in paper Implicit Behavioral Cloning. + https://arxiv.org/abs/2109.00137 + Interface: + ``__init__``, ``sample``, ``infer``, ``grad_penalty`` + """ class BaseScheduler(ABC): + """ + Overview: + Base class for learning rate scheduler. + Interface: + ``get_rate`` + """ @abstractmethod def get_rate(self, index): + """ + Overview: + Abstract method for getting learning rate. 
+ """ raise NotImplementedError class ExponentialScheduler: - """Exponential learning rate schedule for Langevin sampler.""" + """ + Overview: + Exponential learning rate schedule for Langevin sampler. + Interface: + ``__init__``, ``get_rate`` + """ def __init__(self, init, decay): + """ + Overview: + Initialize the ExponentialScheduler. + Arguments: + - init (:obj:`float`): Initial learning rate. + - decay (:obj:`float`): Decay rate. + """ self._decay = decay self._latest_lr = init def get_rate(self, index): - """Get learning rate. Assumes calling sequentially.""" + """ + Overview: + Get learning rate. Assumes calling sequentially. + Arguments: + - index (:obj:`int`): Current iteration. + """ del index lr = self._latest_lr self._latest_lr *= self._decay return lr class PolynomialScheduler: - """Polynomial learning rate schedule for Langevin sampler.""" + """ + Overview: + Polynomial learning rate schedule for Langevin sampler. + Interface: + ``__init__``, ``get_rate`` + """ def __init__(self, init, final, power, num_steps): + """ + Overview: + Initialize the PolynomialScheduler. + Arguments: + - init (:obj:`float`): Initial learning rate. + - final (:obj:`float`): Final learning rate. + - power (:obj:`float`): Power of polynomial. + - num_steps (:obj:`int`): Number of steps. + """ self._init = init self._final = final self._power = power self._num_steps = num_steps def get_rate(self, index): - """Get learning rate for index.""" + """ + Overview: + Get learning rate for index. + Arguments: + - index (:obj:`int`): Current iteration. + """ if index == -1: return self._init return ( @@ -298,6 +491,26 @@ def __init__( grad_loss_weight: float = 1.0, **kwargs, ): + """ + Overview: + Initialize the MCMC. + Arguments: + - iters (:obj:`int`): Number of iterations. + - use_langevin_negative_samples (:obj:`bool`): Whether to use Langevin sampler. + - train_samples (:obj:`int`): Number of samples for training. + - inference_samples (:obj:`int`): Number of samples for inference. + - stepsize_scheduler (:obj:`dict`): Step size scheduler for Langevin sampler. + - optimize_again (:obj:`bool`): Whether to run a second optimization. + - again_stepsize_scheduler (:obj:`dict`): Step size scheduler for the second optimization. + - device (:obj:`str`): Device. + - noise_scale (:obj:`float`): Initial noise scale. + - grad_clip (:obj:`float`): Gradient clip. + - delta_action_clip (:obj:`float`): Action clip. + - add_grad_penalty (:obj:`bool`): Whether to add gradient penalty. + - grad_norm_type (:obj:`str`): Gradient norm type. + - grad_margin (:obj:`float`): Gradient margin. + - grad_loss_weight (:obj:`float`): Gradient loss weight. + """ self.iters = iters self.use_langevin_negative_samples = use_langevin_negative_samples self.train_samples = train_samples @@ -323,9 +536,20 @@ def _gradient_wrt_act( create_graph: bool = False, ) -> torch.Tensor: """ - Calculate gradient w.r.t action. - obs: (B, N, O), action: (B, N, A). - return: (B, N, A). + Overview: + Calculate gradient w.r.t action. + Arguments: + - obs (:obj:`torch.Tensor`): Observations. + - action (:obj:`torch.Tensor`): Actions. + - ebm (:obj:`torch.nn.Module`): Energy based model. + - create_graph (:obj:`bool`): Whether to create graph. + Returns: + - grad (:obj:`torch.Tensor`): Gradient w.r.t action. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, N, O)`. + - action (:obj:`torch.Tensor`): :math:`(B, N, A)`. + - ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`. + - grad (:obj:`torch.Tensor`): :math:`(B, N, A)`. 
""" action.requires_grad_(True) energy = ebm.forward(obs, action).sum() @@ -337,9 +561,19 @@ def _gradient_wrt_act( def grad_penalty(self, obs: torch.Tensor, action: torch.Tensor, ebm: nn.Module) -> torch.Tensor: """ - Calculate gradient penalty. - obs: (B, N+1, O), action: (B, N+1, A). - return: loss. + Overview: + Calculate gradient penalty. + Arguments: + - obs (:obj:`torch.Tensor`): Observations. + - action (:obj:`torch.Tensor`): Actions. + - ebm (:obj:`torch.nn.Module`): Energy based model. + Returns: + - loss (:obj:`torch.Tensor`): Gradient penalty. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, N+1, O)`. + - action (:obj:`torch.Tensor`): :math:`(B, N+1, A)`. + - ebm (:obj:`torch.nn.Module`): :math:`(B, N+1, O)`. + - loss (:obj:`torch.Tensor`): :math:`(B, )`. """ if not self.add_grad_penalty: return 0. @@ -371,9 +605,20 @@ def compute_grad_norm(grad_norm_type, de_dact) -> torch.Tensor: @no_ebm_grad() def _langevin_step(self, obs: torch.Tensor, action: torch.Tensor, stepsize: float, ebm: nn.Module) -> torch.Tensor: """ - Run one langevin MCMC step. - obs: (B, N, O), action: (B, N, A) - return: (B, N, A). + Overview: + Run one langevin MCMC step. + Arguments: + - obs (:obj:`torch.Tensor`): Observations. + - action (:obj:`torch.Tensor`): Actions. + - stepsize (:obj:`float`): Step size. + - ebm (:obj:`torch.nn.Module`): Energy based model. + Returns: + - action (:obj:`torch.Tensor`): Actions. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, N, O)`. + - action (:obj:`torch.Tensor`): :math:`(B, N, A)`. + - stepsize (:obj:`float`): :math:`(B, )`. + - ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`. """ l_lambda = 1.0 de_dact = MCMC._gradient_wrt_act(obs, action, ebm) @@ -402,9 +647,19 @@ def _langevin_action_given_obs( scheduler: BaseScheduler = None ) -> torch.Tensor: """ - Run langevin MCMC for `self.iters` steps. - obs: (B, N, O), action: (B, N, A) - return: (B, N, A) + Overview: + Run langevin MCMC for `self.iters` steps. + Arguments: + - obs (:obj:`torch.Tensor`): Observations. + - action (:obj:`torch.Tensor`): Actions. + - ebm (:obj:`torch.nn.Module`): Energy based model. + - scheduler (:obj:`BaseScheduler`): Learning rate scheduler. + Returns: + - action (:obj:`torch.Tensor`): Actions. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, N, O)`. + - action (:obj:`torch.Tensor`): :math:`(B, N, A)`. + - ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`. """ if not scheduler: self.stepsize_scheduler['num_steps'] = self.iters @@ -417,6 +672,27 @@ def _langevin_action_given_obs( @no_ebm_grad() def sample(self, obs: torch.Tensor, ebm: nn.Module) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Create tiled observations and sample counter-negatives for InfoNCE loss. + Arguments: + - obs (:obj:`torch.Tensor`): Observations. + - ebm (:obj:`torch.nn.Module`): Energy based model. + Returns: + - tiled_obs (:obj:`torch.Tensor`): Tiled observations. + - action_samples (:obj:`torch.Tensor`): Action samples. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, O)`. + - ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`. + - tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`. + - action_samples (:obj:`torch.Tensor`): :math:`(B, N, A)`. 
+ Examples: + >>> obs = torch.randn(2, 4) + >>> ebm = EBM(4, 5) + >>> opt = MCMC() + >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0)) + >>> tiled_obs, action_samples = opt.sample(obs, ebm) + """ obs, uniform_action_samples = self._sample(obs, self.train_samples) if not self.use_langevin_negative_samples: return obs, uniform_action_samples @@ -425,6 +701,25 @@ def sample(self, obs: torch.Tensor, ebm: nn.Module) -> Tuple[torch.Tensor, torch @no_ebm_grad() def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor: + """ + Overview: + Optimize for the best action conditioned on the current observation. + Arguments: + - obs (:obj:`torch.Tensor`): Observations. + - ebm (:obj:`torch.nn.Module`): Energy based model. + Returns: + - best_action_samples (:obj:`torch.Tensor`): Actions. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, O)`. + - ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`. + - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`. + Examples: + >>> obs = torch.randn(2, 4) + >>> ebm = EBM(4, 5) + >>> opt = MCMC() + >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0)) + >>> best_action_samples = opt.infer(obs, ebm) + """ # (B, N, O), (B, N, A) obs, uniform_action_samples = self._sample(obs, self.inference_samples) action_samples = self._langevin_action_given_obs( @@ -449,6 +744,12 @@ def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor: @MODEL_REGISTRY.register('ebm') class EBM(nn.Module): + """ + Overview: + Energy based model. + Interface: + ``__init__``, ``forward`` + """ def __init__( self, @@ -458,6 +759,15 @@ def __init__( hidden_layer_num: int = 4, **kwargs, ): + """ + Overview: + Initialize the EBM. + Arguments: + - obs_shape (:obj:`int`): Observation shape. + - action_shape (:obj:`int`): Action shape. + - hidden_size (:obj:`int`): Hidden size. + - hidden_layer_num (:obj:`int`): Number of hidden layers. + """ super().__init__() input_size = obs_shape + action_shape self.net = nn.Sequential( @@ -471,9 +781,20 @@ def __init__( ) def forward(self, obs, action): - # obs: (B, N, O) - # action: (B, N, A) - # return: (B, N) + """ + Overview: + Forward computation graph of EBM. + Arguments: + - obs (:obj:`torch.Tensor`): Observation of shape (B, N, O). + - action (:obj:`torch.Tensor`): Action of shape (B, N, A). + Returns: + - pred (:obj:`torch.Tensor`): Energy of shape (B, N). + Examples: + >>> obs = torch.randn(2, 3, 4) + >>> action = torch.randn(2, 3, 5) + >>> ebm = EBM(4, 5) + >>> pred = ebm(obs, action) + """ x = torch.cat([obs, action], -1) x = self.net(x) return x['pred'] @@ -481,6 +802,12 @@ def forward(self, obs, action): @MODEL_REGISTRY.register('arebm') class AutoregressiveEBM(nn.Module): + """ + Overview: + Autoregressive energy based model. + Interface: + ``__init__``, ``forward`` + """ def __init__( self, @@ -488,21 +815,37 @@ def __init__( action_shape: int, hidden_size: int = 512, hidden_layer_num: int = 4, - **kwargs, ): + """ + Overview: + Initialize the AutoregressiveEBM. + Arguments: + - obs_shape (:obj:`int`): Observation shape. + - action_shape (:obj:`int`): Action shape. + - hidden_size (:obj:`int`): Hidden size. + - hidden_layer_num (:obj:`int`): Number of hidden layers. 
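# Editor's note: an illustrative sketch, not part of this patch. The InfoNCE objective mentioned in
# the ``sample`` docstrings above is, in spirit, a softmax cross-entropy over energies in which the
# dataset action competes with the sampled counter-negatives; index 0 below is assumed to hold the
# positive action. The actual training loss is defined in the IBC policy, not in this file.
import torch
import torch.nn.functional as F

energies = torch.randn(16, 1 + 256)           # (B, 1 + N): energy of the positive action, then N negatives
labels = torch.zeros(16, dtype=torch.long)    # the positive sits at index 0
info_nce_loss = F.cross_entropy(-energies, labels)  # low energy on the positive action lowers the loss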
+ """ super().__init__() self.ebm_list = nn.ModuleList() for i in range(action_shape): self.ebm_list.append(EBM(obs_shape, i + 1, hidden_size, hidden_layer_num)) def forward(self, obs, action): - # obs: (B, N, O) - # action: (B, N, A) - # return: (B, N, A) - - # (B, N) + """ + Overview: + Forward computation graph of AutoregressiveEBM. + Arguments: + - obs (:obj:`torch.Tensor`): Observation of shape (B, N, O). + - action (:obj:`torch.Tensor`): Action of shape (B, N, A). + Returns: + - pred (:obj:`torch.Tensor`): Energy of shape (B, N, A). + Examples: + >>> obs = torch.randn(2, 3, 4) + >>> action = torch.randn(2, 3, 5) + >>> arebm = AutoregressiveEBM(4, 5) + >>> pred = arebm(obs, action) + """ output_list = [] for i, ebm in enumerate(self.ebm_list): output_list.append(ebm(obs, action[..., :i + 1])) - # (B, N, A) return torch.stack(output_list, axis=-1) diff --git a/ding/model/template/maqac.py b/ding/model/template/maqac.py index 69c6d4cee0..ba74b97573 100644 --- a/ding/model/template/maqac.py +++ b/ding/model/template/maqac.py @@ -13,7 +13,10 @@ class DiscreteMAQAC(nn.Module): """ Overview: - The discrete action Multi-Agent Q-value Actor-CritiC (MAQAC) model. + The neural network and computation graph of algorithms related to discrete action Multi-Agent Q-value \ + Actor-CritiC (MAQAC) model. The model is composed of actor and critic, where actor is a MLP network and \ + critic is a MLP network. The actor network is used to predict the action probability distribution, and the \ + critic network is used to predict the Q value of the state-action pair. Interfaces: ``__init__``, ``forward``, ``compute_actor``, ``compute_critic`` """ @@ -34,7 +37,7 @@ def __init__( ) -> None: """ Overview: - Init the DiscreteMAQAC Model according to arguments. + Initialize the DiscreteMAQAC Model according to arguments. Arguments: - agent_obs_shape (:obj:`Union[int, SequenceType]`): Agent's observation's space. - global_obs_shape (:obj:`Union[int, SequenceType]`): Global observation's space. @@ -91,93 +94,189 @@ def __init__( ) def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict: - r""" + """ Overview: - Use observation and action tensor to predict output. - Parameter updates with QAC's MLPs forward setup. + Use observation tensor to predict output, with ``'compute_actor'`` or ``'compute_critic'`` mode. Arguments: - Forward with ``'compute_actor'``: - - inputs (:obj:`torch.Tensor`): - The encoded embedding tensor, determined with given ``hidden_size``, i.e. ``(B, N=hidden_size)``. - Whether ``actor_head_hidden_size`` or ``critic_head_hidden_size`` depend on ``mode``. - Forward with ``'compute_critic'``, inputs (`Dict`) Necessary Keys: - - ``obs``, ``action`` encoded tensors. - - mode (:obj:`str`): Name of the forward mode. + - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys: + - ``obs`` (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys: + - ``agent_state`` (:obj:`torch.Tensor`): The agent's observation tensor data, \ + with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. \ + N0 corresponds to ``agent_obs_shape``. + - ``global_state`` (:obj:`torch.Tensor`): The global observation tensor data, \ + with shape :math:`(B, A, N1)`, where B is batch size and A is agent num. \ + N1 corresponds to ``global_obs_shape``. + - ``action_mask`` (:obj:`torch.Tensor`): The action mask tensor data, \ + with shape :math:`(B, A, N2)`, where B is batch size and A is agent num. \ + N2 corresponds to ``action_shape``. 
+ - mode (:obj:`str`): The forward mode, all the modes are defined in the beginning of this class. Returns: - - outputs (:obj:`Dict`): Outputs of network forward. - Forward with ``'compute_actor'``, Necessary Keys (either): - - action (:obj:`torch.Tensor`): Action tensor with same size as input ``x``. - - logit (:obj:`torch.Tensor`): Action's probabilities. - Forward with ``'compute_critic'``, Necessary Keys: - - q_value (:obj:`torch.Tensor`): Q value tensor with same size as batch size. - Actor Shapes: - - inputs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 corresponds to ``hidden_size`` - - action (:obj:`torch.Tensor`): :math:`(B, N0)` - - q_value (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size. - Critic Shapes: - - obs (:obj:`torch.Tensor`): :math:`(B, N1)`, where B is batch size and N1 is ``global_obs_shape`` - - logit (:obj:`torch.FloatTensor`): :math:`(B, N2)`, where B is batch size and N2 is ``action_shape`` + - output (:obj:`Dict[str, torch.Tensor]`): The output dict of DiscreteMAQAC forward computation graph, whose \ + key-values vary in different forward modes. + Forward with ``'compute_actor'``, Necessary Keys (either): + - logit (:obj:`torch.Tensor`): Action's probabilities. + - action_mask (:obj:`torch.Tensor`): Action mask tensor with same size as ``action_shape``. + Forward with ``'compute_critic'``, if ``twin_critic`` is ``False``, Necessary Keys: + - q_value (:obj:`torch.Tensor`): Q value tensor is the shape of :math:`(B, A, N2)`, where B is batch size \ + and A is agent num. N2 corresponds to ``action_shape``. + Forward with ``'compute_critic'``, if ``twin_critic`` is ``True``, Necessary Keys: + - q_value (:obj:`list`): 2 elements, each is the shape of :math:`(B, A, N2)`, where B is batch size and \ + A is agent num. N2 corresponds to ``action_shape``. + Shapes: + - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys: + - ``obs`` (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys: + - ``agent_state`` (:obj:`torch.Tensor`): :math:`(B, A, N0)`, where B is batch size and A is agent num. \ + N0 corresponds to ``agent_obs_shape``. + - ``global_state`` (:obj:`torch.Tensor`): :math:`(B, A, N1)`, where B is batch size and A is agent num. \ + N1 corresponds to ``global_obs_shape``. + - ``action_mask`` (:obj:`torch.Tensor`): :math:`(B, A, N2)`, where B is batch size and A is agent num. \ + N2 corresponds to ``action_shape``. + - output (:obj:`Dict[str, torch.Tensor]`): The output dict of DiscreteMAQAC forward computation graph, whose \ + key-values vary in different forward modes. + Forward with ``'compute_actor'``, Necessary Keys (either): + - logit (:obj:`torch.Tensor`): :math:`(B, A, N2)`, where B is batch size and A is agent num. \ + N2 corresponds to ``action_shape``. + - action_mask (:obj:`torch.Tensor`): :math:`(B, A, N2)`, where B is batch size and A is agent num. \ + N2 corresponds to ``action_shape``. + Forward with ``'compute_critic'``, if ``twin_critic`` is ``True``, Necessary Keys: + - q_value (:obj:`list`): 2 elements, each is the shape of :math:`(B, A, N2)`, where B is batch size and \ + A is agent num. N2 corresponds to ``action_shape``. + Forward with ``'compute_critic'``, if ``twin_critic`` is ``False``, Necessary Keys: + - q_value (:obj:`torch.Tensor`): :math:`(B, A, N2)`, where B is batch size and A is agent num. \ + N2 corresponds to ``action_shape``. 
+ Examples: + >>> B = 32 + >>> agent_obs_shape = 216 + >>> global_obs_shape = 264 + >>> agent_num = 8 + >>> action_shape = 14 + >>> data = { + >>> 'obs': { + >>> 'agent_state': torch.randn(B, agent_num, agent_obs_shape), + >>> 'global_state': torch.randn(B, agent_num, global_obs_shape), + >>> 'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape)) + >>> } + >>> } + >>> model = DiscreteMAQAC(agent_obs_shape, global_obs_shape, action_shape, twin_critic=True) + >>> logit = model(data, mode='compute_actor')['logit'] + >>> value = model(data, mode='compute_critic')['q_value'] """ assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode) return getattr(self, mode)(inputs) def compute_actor(self, inputs: Dict) -> Dict: - r""" + """ Overview: - Use encoded embedding tensor to predict output. - Execute parameter updates with ``'compute_actor'`` mode - Use encoded embedding tensor to predict output. + Use observation tensor to predict action logits. Arguments: - - inputs (:obj:`torch.Tensor`): - The encoded embedding tensor, determined with given ``hidden_size``, i.e. ``(B, N=hidden_size)``. - ``hidden_size = actor_head_hidden_size`` - - mode (:obj:`str`): Name of the forward mode. + - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys: + - ``obs`` (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys: + - ``agent_state`` (:obj:`torch.Tensor`): The agent's observation tensor data, \ + with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. \ + N0 corresponds to ``agent_obs_shape``. + - ``global_state`` (:obj:`torch.Tensor`): The global observation tensor data, \ + with shape :math:`(B, A, N1)`, where B is batch size and A is agent num. \ + N1 corresponds to ``global_obs_shape``. + - ``action_mask`` (:obj:`torch.Tensor`): The action mask tensor data, \ + with shape :math:`(B, A, N2)`, where B is batch size and A is agent num. \ + N2 corresponds to ``action_shape``. Returns: - - outputs (:obj:`Dict`): Outputs of forward pass encoder and head. - ReturnsKeys (either): - - action (:obj:`torch.Tensor`): Continuous action tensor with same size as ``action_shape``. - - logit (:obj:`torch.Tensor`): - Logit tensor encoding ``mu`` and ``sigma``, both with same size as input ``x``. + - output (:obj:`Dict[str, torch.Tensor]`): The output dict of DiscreteMAQAC forward computation graph, whose \ + key-values vary in different forward modes. + - logit (:obj:`torch.Tensor`): Action's probabilities. + - action_mask (:obj:`torch.Tensor`): Action mask tensor with same size as ``action_shape``. Shapes: - - inputs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 corresponds to ``hidden_size`` - - action (:obj:`torch.Tensor`): :math:`(B, N0)` - - logit (:obj:`list`): 2 elements, mu and sigma, each is the shape of :math:`(B, N0)`. - - q_value (:obj:`torch.FloatTensor`): :math:`(B, )`, B is batch size. + - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys: + - ``obs`` (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys: + - ``agent_state`` (:obj:`torch.Tensor`): :math:`(B, A, N0)`, where B is batch size and A is agent num. \ + N0 corresponds to ``agent_obs_shape``. + - ``global_state`` (:obj:`torch.Tensor`): :math:`(B, A, N1)`, where B is batch size and A is agent num. \ + N1 corresponds to ``global_obs_shape``. + - ``action_mask`` (:obj:`torch.Tensor`): :math:`(B, A, N2)`, where B is batch size and A is agent num. \ + N2 corresponds to ``action_shape``. 
+ - output (:obj:`Dict[str, torch.Tensor]`): The output dict of DiscreteMAQAC forward computation graph, whose \ + key-values vary in different forward modes. + - logit (:obj:`torch.Tensor`): :math:`(B, A, N2)`, where B is batch size and A is agent num. \ + N2 corresponds to ``action_shape``. + - action_mask (:obj:`torch.Tensor`): :math:`(B, A, N2)`, where B is batch size and A is agent num. \ + N2 corresponds to ``action_shape``. Examples: - >>> # Regression mode - >>> model = DiscreteQAC(64, 64, 'regression') - >>> inputs = torch.randn(4, 64) - >>> actor_outputs = model(inputs,'compute_actor') - >>> assert actor_outputs['action'].shape == torch.Size([4, 64]) - >>> # Reparameterization Mode - >>> model = DiscreteQAC(64, 64, 'reparameterization') - >>> inputs = torch.randn(4, 64) - >>> actor_outputs = model(inputs,'compute_actor') - >>> actor_outputs['logit'][0].shape # mu - >>> torch.Size([4, 64]) - >>> actor_outputs['logit'][1].shape # sigma - >>> torch.Size([4, 64]) + >>> B = 32 + >>> agent_obs_shape = 216 + >>> global_obs_shape = 264 + >>> agent_num = 8 + >>> action_shape = 14 + >>> data = { + >>> 'obs': { + >>> 'agent_state': torch.randn(B, agent_num, agent_obs_shape), + >>> 'global_state': torch.randn(B, agent_num, global_obs_shape), + >>> 'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape)) + >>> } + >>> } + >>> model = DiscreteMAQAC(agent_obs_shape, global_obs_shape, action_shape, twin_critic=True) + >>> logit = model.compute_actor(data)['logit'] """ action_mask = inputs['obs']['action_mask'] x = self.actor(inputs['obs']['agent_state']) return {'logit': x['logit'], 'action_mask': action_mask} def compute_critic(self, inputs: Dict) -> Dict: - r""" + """ Overview: - Execute parameter updates with ``'compute_critic'`` mode - Use encoded embedding tensor to predict output. + use observation tensor to predict Q value. Arguments: - - ``obs``, ``action`` encoded tensors. - - mode (:obj:`str`): Name of the forward mode. + - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys: + - ``obs`` (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys: + - ``agent_state`` (:obj:`torch.Tensor`): The agent's observation tensor data, \ + with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. \ + N0 corresponds to ``agent_obs_shape``. + - ``global_state`` (:obj:`torch.Tensor`): The global observation tensor data, \ + with shape :math:`(B, A, N1)`, where B is batch size and A is agent num. \ + N1 corresponds to ``global_obs_shape``. + - ``action_mask`` (:obj:`torch.Tensor`): The action mask tensor data, \ + with shape :math:`(B, A, N2)`, where B is batch size and A is agent num. \ + N2 corresponds to ``action_shape``. Returns: - - outputs (:obj:`Dict`): Q-value output. - ReturnKeys: - - q_value (:obj:`torch.Tensor`): Q value tensor with same size as batch size. + - output (:obj:`Dict[str, torch.Tensor]`): The output dict of DiscreteMAQAC forward computation graph, whose \ + key-values vary in different forward modes. + Forward with ``'compute_critic'``, if ``twin_critic`` is ``False``, Necessary Keys: + - q_value (:obj:`torch.Tensor`): Q value tensor is the shape of :math:`(B, A, N2)`, where B is batch size \ + and A is agent num. N2 corresponds to ``action_shape``. + Forward with ``'compute_critic'``, if ``twin_critic`` is ``True``, Necessary Keys: + - q_value (:obj:`list`): 2 elements, each is the shape of :math:`(B, A, N2)`, where B is batch size and \ + A is agent num. N2 corresponds to ``action_shape``. 
Shapes: - - obs (:obj:`torch.Tensor`): :math:`(B, N1)`, where B is batch size and N1 is ``obs_shape`` - - action (:obj:`torch.Tensor`): :math:`(B, N2)`, where B is batch size and N2 is ``action_shape`` - - q_value (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size. + - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys: + - ``obs`` (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys: + - ``agent_state`` (:obj:`torch.Tensor`): :math:`(B, A, N0)`, where B is batch size and A is agent num. \ + N0 corresponds to ``agent_obs_shape``. + - ``global_state`` (:obj:`torch.Tensor`): :math:`(B, A, N1)`, where B is batch size and A is agent num. \ + N1 corresponds to ``global_obs_shape``. + - ``action_mask`` (:obj:`torch.Tensor`): :math:`(B, A, N2)`, where B is batch size and A is agent num. \ + N2 corresponds to ``action_shape``. + - output (:obj:`Dict[str, torch.Tensor]`): The output dict of DiscreteMAQAC forward computation graph, whose \ + key-values vary in different forward modes. + if ``twin_critic`` is ``True``, Necessary Keys: + - q_value (:obj:`list`): 2 elements, each is the shape of :math:`(B, A, N2)`, where B is batch size and \ + A is agent num. N2 corresponds to ``action_shape``. + if ``twin_critic`` is ``False``, Necessary Keys: + - q_value (:obj:`torch.Tensor`): :math:`(B, A, N2)`, where B is batch size and A is agent num. \ + N2 corresponds to ``action_shape``. + Examples: + >>> B = 32 + >>> agent_obs_shape = 216 + >>> global_obs_shape = 264 + >>> agent_num = 8 + >>> action_shape = 14 + >>> data = { + >>> 'obs': { + >>> 'agent_state': torch.randn(B, agent_num, agent_obs_shape), + >>> 'global_state': torch.randn(B, agent_num, global_obs_shape), + >>> 'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape)) + >>> } + >>> } + >>> model = DiscreteMAQAC(agent_obs_shape, global_obs_shape, action_shape, twin_critic=True) + >>> value = model.compute_critic(data)['q_value'] """ if self.twin_critic: @@ -191,7 +290,10 @@ def compute_critic(self, inputs: Dict) -> Dict: class ContinuousMAQAC(nn.Module): """ Overview: - The continuous action Multi-Agent Q-value Actor-CritiC (MAQAC) model. + The neural network and computation graph of algorithms related to continuous action Multi-Agent Q-value \ + Actor-CritiC (MAQAC) model. The model is composed of actor and critic, where actor is a MLP network and \ + critic is a MLP network. The actor network is used to predict the action probability distribution, and the \ + critic network is used to predict the Q value of the state-action pair. Interfaces: ``__init__``, ``forward``, ``compute_actor``, ``compute_critic`` """ @@ -213,7 +315,7 @@ def __init__( ) -> None: """ Overview: - Init the QAC Model according to arguments. + Initialize the QAC Model according to arguments. Arguments: - obs_shape (:obj:`Union[int, SequenceType]`): Observation's space. - action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action's space, such as 4, (3, ) @@ -293,101 +395,131 @@ def __init__( ) def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict: - r""" + """ Overview: - Use observation and action tensor to predict output. - Parameter updates with QAC's MLPs forward setup. + Use observation and action tensor to predict output in ``'compute_actor'`` or ``'compute_critic'`` mode. Arguments: - Forward with ``'compute_actor'``: - - inputs (:obj:`torch.Tensor`): - The encoded embedding tensor, determined with given ``hidden_size``, i.e. ``(B, N=hidden_size)``. 
- Whether ``actor_head_hidden_size`` or ``critic_head_hidden_size`` depend on ``mode``. - - Forward with ``'compute_critic'``, inputs (`Dict`) Necessary Keys: - - ``obs``, ``action`` encoded tensors. - + - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys: + - ``obs`` (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys: + - ``agent_state`` (:obj:`torch.Tensor`): The agent's observation tensor data, \ + with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. \ + N0 corresponds to ``agent_obs_shape``. + - ``global_state`` (:obj:`torch.Tensor`): The global observation tensor data, \ + with shape :math:`(B, A, N1)`, where B is batch size and A is agent num. \ + N1 corresponds to ``global_obs_shape``. + - ``action_mask`` (:obj:`torch.Tensor`): The action mask tensor data, \ + with shape :math:`(B, A, N2)`, where B is batch size and A is agent num. \ + N2 corresponds to ``action_shape``. + - ``action`` (:obj:`torch.Tensor`): The action tensor data, \ + with shape :math:`(B, A, N3)`, where B is batch size and A is agent num. \ + N3 corresponds to ``action_shape``. - mode (:obj:`str`): Name of the forward mode. Returns: - outputs (:obj:`Dict`): Outputs of network forward. - - Forward with ``'compute_actor'``, Necessary Keys (either): - - action (:obj:`torch.Tensor`): Action tensor with same size as input ``x``. - - logit (:obj:`torch.Tensor`): - Logit tensor encoding ``mu`` and ``sigma``, both with same size as input ``x``. - - Forward with ``'compute_critic'``, Necessary Keys: - - q_value (:obj:`torch.Tensor`): Q value tensor with same size as batch size. - Actor Shapes: - - inputs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 corresponds to ``hidden_size`` - - action (:obj:`torch.Tensor`): :math:`(B, N0)` - - q_value (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size. - - Critic Shapes: - - obs (:obj:`torch.Tensor`): :math:`(B, N1)`, where B is batch size and N1 is ``obs_shape`` - - action (:obj:`torch.Tensor`): :math:`(B, N2)`, where B is batch size and N2 is``action_shape`` - - logit (:obj:`torch.FloatTensor`): :math:`(B, N2)`, where B is batch size and N3 is ``action_shape`` - - Actor Examples: - >>> # Regression mode - >>> model = ContinuousQAC(64, 64, 'regression') - >>> inputs = torch.randn(4, 64) - >>> actor_outputs = model(inputs,'compute_actor') - >>> assert actor_outputs['action'].shape == torch.Size([4, 64]) - >>> # Reparameterization Mode - >>> model = ContinuousQAC(64, 64, 'reparameterization') - >>> inputs = torch.randn(4, 64) - >>> actor_outputs = model(inputs,'compute_actor') - >>> actor_outputs['logit'][0].shape # mu - >>> torch.Size([4, 64]) - >>> actor_outputs['logit'][1].shape # sigma - >>> torch.Size([4, 64]) - + Forward with ``'compute_actor'``, if action_space == 'regression', Necessary Keys: + - action (:obj:`torch.Tensor`): Action tensor with same size as ``action_shape``. + Forward with ``'compute_actor'``, if action_space == 'reparameterization', Necessary Keys: + - logit (:obj:`list`): 2 elements, each is the shape of :math:`(B, A, N3)`, where B is batch size and \ + A is agent num. N3 corresponds to ``action_shape``. + Forward with ``'compute_critic'``, if ``twin_critic`` is ``True``, Necessary Keys: + - q_value (:obj:`list`): 2 elements, each is the shape of :math:`(B, A)`, where B is batch size and \ + A is agent num. 
+ Forward with ``'compute_critic'``, if ``twin_critic`` is ``False``, Necessary Keys: + - q_value (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is agent num. + Shapes: + - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys: + - ``obs`` (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys: + - ``agent_state`` (:obj:`torch.Tensor`): :math:`(B, A, N0)`, where B is batch size and A is agent num. \ + N0 corresponds to ``agent_obs_shape``. + - ``global_state`` (:obj:`torch.Tensor`): :math:`(B, A, N1)`, where B is batch size and A is agent num. \ + N1 corresponds to ``global_obs_shape``. + - ``action_mask`` (:obj:`torch.Tensor`): :math:`(B, A, N2)`, where B is batch size and A is agent num. \ + N2 corresponds to ``action_shape``. + - ``action`` (:obj:`torch.Tensor`): :math:`(B, A, N3)`, where B is batch size and A is agent num. \ + N3 corresponds to ``action_shape``. + - outputs (:obj:`Dict`): Outputs of network forward. + Forward with ``'compute_actor'``, if action_space == 'regression', Necessary Keys: + - action (:obj:`torch.Tensor`): :math:`(B, A, N3)`, where B is batch size and A is agent num. \ + N3 corresponds to ``action_shape``. + Forward with ``'compute_actor'``, if action_space == 'reparameterization', Necessary Keys: + - logit (:obj:`list`): 2 elements, each is the shape of :math:`(B, A, N3)`, where B is batch size and \ + A is agent num. N3 corresponds to ``action_shape``. + Forward with ``'compute_critic'``, if ``twin_critic`` is ``True``, Necessary Keys: + - q_value (:obj:`list`): 2 elements, each is the shape of :math:`(B, A)`, where B is batch size and \ + A is agent num. + Forward with ``'compute_critic'``, if ``twin_critic`` is ``False``, Necessary Keys: + - q_value (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is agent num. + Examples: + >>> B = 32 + >>> agent_obs_shape = 216 + >>> global_obs_shape = 264 + >>> agent_num = 8 + >>> action_shape = 14 + >>> action_space = 'regression' + >>> # or + >>> action_space = 'reparameterization' + >>> data = { + >>> 'obs': { + >>> 'agent_state': torch.randn(B, agent_num, agent_obs_shape), + >>> 'global_state': torch.randn(B, agent_num, global_obs_shape), + >>> 'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape)) + >>> }, + >>> 'action': torch.randn(B, agent_num, squeeze(action_shape)) + >>> } + >>> model = ContinuousMAQAC(agent_obs_shape, global_obs_shape, action_shape, action_space, twin_critic=False) + >>> if action_space == 'regression': + >>> action = model(data['obs'], mode='compute_actor')['action'] + >>> elif action_space == 'reparameterization': + >>> (mu, sigma) = model(data['obs'], mode='compute_actor')['logit'] + >>> value = model(data, mode='compute_critic')['q_value'] """ assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode) return getattr(self, mode)(inputs) def compute_actor(self, inputs: Dict) -> Dict: - r""" + """ Overview: - Use encoded embedding tensor to predict output. - Execute parameter updates with ``'compute_actor'`` mode - Use encoded embedding tensor to predict output. + Use observation tensor to predict action logits. Arguments: - - inputs (:obj:`torch.Tensor`): - The encoded embedding tensor, determined with given ``hidden_size``, i.e. ``(B, N=hidden_size)``. - ``hidden_size = actor_head_hidden_size`` - - mode (:obj:`str`): Name of the forward mode. 
+ - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys: + - ``agent_state`` (:obj:`torch.Tensor`): The agent's observation tensor data, \ + with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. \ + N0 corresponds to ``agent_obs_shape``. Returns: - - outputs (:obj:`Dict`): Outputs of forward pass encoder and head. - - ReturnsKeys (either): - - action (:obj:`torch.Tensor`): Continuous action tensor with same size as ``action_shape``. - - logit (:obj:`torch.Tensor`): - Logit tensor encoding ``mu`` and ``sigma``, both with same size as input ``x``. - - logit + action_args + - outputs (:obj:`Dict`): Outputs of network forward. + if action_space == 'regression', Necessary Keys: + - action (:obj:`torch.Tensor`): Action tensor with same size as ``action_shape``. + if action_space == 'reparameterization', Necessary Keys: + - logit (:obj:`list`): 2 elements, each is the shape of :math:`(B, A, N3)`, where B is batch size and \ + A is agent num. N3 corresponds to ``action_shape``. Shapes: - - inputs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 corresponds to ``hidden_size`` - - action (:obj:`torch.Tensor`): :math:`(B, N0)` - - logit (:obj:`Union[list, torch.Tensor]`): - - case1(continuous space, list): 2 elements, mu and sigma, each is the shape of :math:`(B, N0)`. - - case2(hybrid space, torch.Tensor): :math:`(B, N1)`, where N1 is action_type_shape - - q_value (:obj:`torch.FloatTensor`): :math:`(B, )`, B is batch size. - - action_args (:obj:`torch.FloatTensor`): :math:`(B, N2)`, where N2 is action_args_shape - (action_args are continuous real value) + - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys: + - ``agent_state`` (:obj:`torch.Tensor`): :math:`(B, A, N0)`, where B is batch size and A is agent num. \ + N0 corresponds to ``agent_obs_shape``. + - outputs (:obj:`Dict`): Outputs of network forward. + if action_space == 'regression', Necessary Keys: + - action (:obj:`torch.Tensor`): :math:`(B, A, N3)`, where B is batch size and A is agent num. \ + N3 corresponds to ``action_shape``. + if action_space == 'reparameterization', Necessary Keys: + - logit (:obj:`list`): 2 elements, each is the shape of :math:`(B, A, N3)`, where B is batch size and \ + A is agent num. N3 corresponds to ``action_shape``. 
Examples: - >>> # Regression mode - >>> model = ContinuousQAC(64, 64, 'regression') - >>> inputs = torch.randn(4, 64) - >>> actor_outputs = model(inputs,'compute_actor') - >>> assert actor_outputs['action'].shape == torch.Size([4, 64]) - >>> # Reparameterization Mode - >>> model = ContinuousQAC(64, 64, 'reparameterization') - >>> inputs = torch.randn(4, 64) - >>> actor_outputs = model(inputs,'compute_actor') - >>> actor_outputs['logit'][0].shape # mu - >>> torch.Size([4, 64]) - >>> actor_outputs['logit'][1].shape # sigma - >>> torch.Size([4, 64]) + >>> B = 32 + >>> agent_obs_shape = 216 + >>> global_obs_shape = 264 + >>> agent_num = 8 + >>> action_shape = 14 + >>> action_space = 'regression' + >>> # or + >>> action_space = 'reparameterization' + >>> data = { + >>> 'agent_state': torch.randn(B, agent_num, agent_obs_shape), + >>> } + >>> model = ContinuousMAQAC(agent_obs_shape, global_obs_shape, action_shape, action_space, twin_critic=False) + >>> if action_space == 'regression': + >>> action = model.compute_actor(data)['action'] + >>> elif action_space == 'reparameterization': + >>> (mu, sigma) = model.compute_actor(data)['logit'] """ inputs = inputs['agent_state'] if self.action_space == 'regression': @@ -398,28 +530,67 @@ def compute_actor(self, inputs: Dict) -> Dict: return {'logit': [x['mu'], x['sigma']]} def compute_critic(self, inputs: Dict) -> Dict: - r""" + """ Overview: - Execute parameter updates with ``'compute_critic'`` mode - Use encoded embedding tensor to predict output. + Use observation tensor and action tensor to predict Q value. Arguments: - - inputs (:obj: `Dict`): ``obs``, ``action`` and ``logit` tensors. - - mode (:obj:`str`): Name of the forward mode. + - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys: + - ``obs`` (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys: + - ``agent_state`` (:obj:`torch.Tensor`): The agent's observation tensor data, \ + with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. \ + N0 corresponds to ``agent_obs_shape``. + - ``global_state`` (:obj:`torch.Tensor`): The global observation tensor data, \ + with shape :math:`(B, A, N1)`, where B is batch size and A is agent num. \ + N1 corresponds to ``global_obs_shape``. + - ``action_mask`` (:obj:`torch.Tensor`): The action mask tensor data, \ + with shape :math:`(B, A, N2)`, where B is batch size and A is agent num. \ + N2 corresponds to ``action_shape``. + - ``action`` (:obj:`torch.Tensor`): The action tensor data, \ + with shape :math:`(B, A, N3)`, where B is batch size and A is agent num. \ + N3 corresponds to ``action_shape``. Returns: - - outputs (:obj:`Dict`): Q-value output. - - ArgumentsKeys: - - necessary: - - obs: (:obj:`torch.Tensor`): 2-dim vector observation - - action (:obj:`Union[torch.Tensor, Dict]`): action from actor - - optional: - - logit (:obj:`torch.Tensor`): discrete action logit - ReturnKeys: - - q_value (:obj:`torch.Tensor`): Q value tensor with same size as batch size. + - outputs (:obj:`Dict`): Outputs of network forward. + if ``twin_critic`` is ``True``, Necessary Keys: + - q_value (:obj:`list`): 2 elements, each is the shape of :math:`(B, A)`, where B is batch size and \ + A is agent num. + if ``twin_critic`` is ``False``, Necessary Keys: + - q_value (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is agent num. 
Shapes:
- - obs (:obj:`torch.Tensor`): :math:`(B, N1)`, where B is batch size and N1 is ``obs_shape``
- - action (:obj:`torch.Tensor`): :math:`(B, N2)`, where B is batch size and N2 is ``action_shape``
- - q_value (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size.
+ - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
+ - ``obs`` (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
+ - ``agent_state`` (:obj:`torch.Tensor`): :math:`(B, A, N0)`, where B is batch size and A is agent num. \
+ N0 corresponds to ``agent_obs_shape``.
+ - ``global_state`` (:obj:`torch.Tensor`): :math:`(B, A, N1)`, where B is batch size and A is agent num. \
+ N1 corresponds to ``global_obs_shape``.
+ - ``action_mask`` (:obj:`torch.Tensor`): :math:`(B, A, N2)`, where B is batch size and A is agent num. \
+ N2 corresponds to ``action_shape``.
+ - ``action`` (:obj:`torch.Tensor`): :math:`(B, A, N3)`, where B is batch size and A is agent num. \
+ N3 corresponds to ``action_shape``.
+ - outputs (:obj:`Dict`): Outputs of network forward.
+ if ``twin_critic`` is ``True``, Necessary Keys:
+ - q_value (:obj:`list`): 2 elements, each is the shape of :math:`(B, A)`, where B is batch size and \
+ A is agent num.
+ if ``twin_critic`` is ``False``, Necessary Keys:
+ - q_value (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is agent num.
+ Examples:
+ >>> B = 32
+ >>> agent_obs_shape = 216
+ >>> global_obs_shape = 264
+ >>> agent_num = 8
+ >>> action_shape = 14
+ >>> action_space = 'regression'
+ >>> # or
+ >>> action_space = 'reparameterization'
+ >>> data = {
+ >>> 'obs': {
+ >>> 'agent_state': torch.randn(B, agent_num, agent_obs_shape),
+ >>> 'global_state': torch.randn(B, agent_num, global_obs_shape),
+ >>> 'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
+ >>> },
+ >>> 'action': torch.randn(B, agent_num, squeeze(action_shape))
+ >>> }
+ >>> model = ContinuousMAQAC(agent_obs_shape, global_obs_shape, action_shape, action_space, twin_critic=False)
+ >>> value = model.compute_critic(data)['q_value']
"""
obs, action = inputs['obs']['global_state'], inputs['action']
diff --git a/ding/model/template/ngu.py b/ding/model/template/ngu.py
index d1c9d9fb99..caa3c14760 100644
--- a/ding/model/template/ngu.py
+++ b/ding/model/template/ngu.py
@@ -129,7 +129,7 @@ def __init__(
)
def forward(self, inputs: Dict, inference: bool = False, saved_state_timesteps: Optional[list] = None) -> Dict:
- r"""
+ """
Overview:
Forward computation graph of NGU R2D2 network. Input observation, prev_action prev_reward_extrinsic \
to predict NGU Q output. Parameter updates with NGU's MLPs forward setup.
diff --git a/ding/model/template/ppg.py b/ding/model/template/ppg.py
index f9dd64d21e..76df579e71 100644
--- a/ding/model/template/ppg.py
+++ b/ding/model/template/ppg.py
@@ -8,6 +8,15 @@
@MODEL_REGISTRY.register('ppg')
class PPG(nn.Module):
+ """
+ Overview:
+ Phasic Policy Gradient (PPG) model from paper `Phasic Policy Gradient` \
+ https://arxiv.org/abs/2009.04416 \
+ This module contains a VAC module and an auxiliary critic module.
+ Interfaces:
+ ``forward``, ``compute_actor``, ``compute_critic``, ``compute_actor_critic``
+ """
+
mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

def __init__(
self,
@@ -25,6 +34,27 @@ def __init__(
norm_type: Optional[str] = None,
impala_cnn_encoder: bool = False,
) -> None:
+ """
+ Overview:
+ Initialize the PPG Model according to input arguments.
+ Arguments:
+ - obs_shape (:obj:`Union[int, SequenceType]`): Observation's shape, such as 128, (156, ).
+ - action_shape (:obj:`Union[int, SequenceType]`): Action's shape, such as 4, (3, ).
+ - action_space (:obj:`str`): The action space type, such as 'discrete', 'continuous'.
+ - share_encoder (:obj:`bool`): Whether to share encoder.
+ - encoder_hidden_size_list (:obj:`SequenceType`): The hidden size list of encoder.
+ - actor_head_hidden_size (:obj:`int`): The ``hidden_size`` to pass to actor head.
+ - actor_head_layer_num (:obj:`int`): The num of layers used in the network to compute action logit output \
+ for actor head.
+ - critic_head_hidden_size (:obj:`int`): The ``hidden_size`` to pass to critic head.
+ - critic_head_layer_num (:obj:`int`): The num of layers used in the network to compute value output \
+ for critic head.
+ - activation (:obj:`Optional[nn.Module]`): The type of activation function to use in ``MLP`` \
+ after each FC layer, if ``None`` then default set to ``nn.ReLU()``.
+ - norm_type (:obj:`Optional[str]`): The type of normalization to use after network layer (FC, Conv), \
+ see ``ding.torch_utils.network`` for more details.
+ - impala_cnn_encoder (:obj:`bool`): Whether to use impala cnn encoder.
+ """
super(PPG, self).__init__()
self.actor_critic = VAC(
obs_shape,
@@ -43,20 +73,53 @@ def __init__(
self.aux_critic = copy.deepcopy(self.actor_critic.critic)
def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict:
+ """
+ Overview:
+ Compute action logits or value according to mode being ``compute_actor``, ``compute_critic`` or \
+ ``compute_actor_critic``.
+ Arguments:
+ - inputs (:obj:`Union[torch.Tensor, Dict]`): The input observation tensor data.
+ - mode (:obj:`str`): The forward mode, all the modes are defined in the beginning of this class.
+ Returns:
+ - outputs (:obj:`Dict`): The output dict of PPG's forward computation graph, whose key-values vary from \
+ different ``mode``.
+ """
assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
return getattr(self, mode)(inputs)
def compute_actor(self, x: torch.Tensor) -> Dict:
"""
+ Overview:
+ Use actor to compute action logits.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input observation tensor data.
+ Returns:
+ - output (:obj:`Dict`): The output data containing action logits.
ReturnsKeys:
- - necessary: ``logit``
+ - logit (:obj:`torch.Tensor`): The predicted action logit tensor, for discrete action space, it will be \
+ the same dimension real-value ranged tensor of possible action choices, and for continuous action \
+ space, it will be the mu and sigma of the Gaussian distribution, and the number of mu and sigma is the \
+ same as the number of continuous actions. Hybrid action space is a kind of combination of discrete \
+ and continuous action space, so the logit will be a dict with ``action_type`` and ``action_args``.
+ Shapes:
+ - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is the input feature size.
+ - output (:obj:`Dict`): ``logit``: :math:`(B, A)`, where B is batch size and A is the action space size.
"""
return self.actor_critic(x, mode='compute_actor')
def compute_critic(self, x: torch.Tensor) -> Dict:
"""
+ Overview:
+ Use critic to compute value.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input observation tensor data.
+ Returns:
+ - output (:obj:`Dict`): The output dict of VAC's forward computation graph for critic, including ``value``.
ReturnsKeys: - necessary: ``value`` + Shapes: + - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is the input feature size. + - output (:obj:`Dict`): ``value``: :math:`(B, 1)`, where B is batch size. """ x = self.aux_critic[0](x) # encoder x = self.aux_critic[1](x) # head @@ -64,10 +127,26 @@ def compute_critic(self, x: torch.Tensor) -> Dict: def compute_actor_critic(self, x: torch.Tensor) -> Dict: """ + Overview: + Use actor and critic to compute action logits and value. + Arguments: + - x (:obj:`torch.Tensor`): The input observation tensor data. + Returns: + - outputs (:obj:`Dict`): The output dict of PPG's forward computation graph for both actor and critic, \ + including ``logit`` and ``value``. + ReturnsKeys: + - logit (:obj:`torch.Tensor`): The predicted action logit tensor, for discrete action space, it will be \ + the same dimension real-value ranged tensor of possible action choices, and for continuous action \ + space, it will be the mu and sigma of the Gaussian distribution, and the number of mu and sigma is the \ + same as the number of continuous actions. Hybrid action space is a kind of combination of discrete \ + and continuous action space, so the logit will be a dict with ``action_type`` and ``action_args``. + - value (:obj:`torch.Tensor`): The predicted state value tensor. + Shapes: + - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is the input feature size. + - output (:obj:`Dict`): ``value``: :math:`(B, 1)`, where B is batch size. + - output (:obj:`Dict`): ``logit``: :math:`(B, A)`, where B is batch size and A is the action space size. + .. note:: ``compute_actor_critic`` interface aims to save computation when shares encoder. - - ReturnsKeys: - - necessary: ``value``, ``logit`` """ return self.actor_critic(x, mode='compute_actor_critic') diff --git a/ding/model/template/q_learning.py b/ding/model/template/q_learning.py index 43c831c174..ece076bd81 100644 --- a/ding/model/template/q_learning.py +++ b/ding/model/template/q_learning.py @@ -199,7 +199,7 @@ def __init__( ) def forward(self, x: torch.Tensor) -> Dict: - r""" + """ Overview: BDQ forward computation graph, input observation tensor to predict q_value. Arguments: @@ -316,13 +316,6 @@ def __init__( def forward(self, x: torch.Tensor) -> Dict: """ - Returns: - - outputs (:obj:`Dict`): The output of DQN's forward, including q_value. - ReturnsKeys: - - logit (:obj:`torch.Tensor`): Discrete Q-value output of each possible action dimension. - Shapes: - - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape`` - - logit (:obj:`torch.Tensor`): :math:`(B, M)`, where B is batch size and M is ``action_shape`` Overview: C51DQN forward computation graph, input observation tensor to predict q_value and its distribution. Arguments: @@ -337,7 +330,6 @@ def forward(self, x: torch.Tensor) -> Dict: - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is head_hidden_size. - logit (:obj:`torch.Tensor`): :math:`(B, M)`, where M is action_shape. - distribution(:obj:`torch.Tensor`): :math:`(B, M, P)`, where P is n_atom. - Examples: >>> model = C51DQN(128, 64) # arguments: 'obs_shape' and 'action_shape' >>> inputs = torch.randn(4, 128) @@ -363,6 +355,14 @@ def forward(self, x: torch.Tensor) -> Dict: @MODEL_REGISTRY.register('qrdqn') class QRDQN(nn.Module): + """ + Overview: + The neural network structure and computation graph of QRDQN, which combines distributional RL and DQN. 
\ + You can refer to Distributional Reinforcement Learning with Quantile Regression \ + https://arxiv.org/pdf/1710.10044.pdf for more details. + Interfaces: + ``__init__``, ``forward`` + """ def __init__( self, @@ -375,9 +375,9 @@ def __init__( activation: Optional[nn.Module] = nn.ReLU(), norm_type: Optional[str] = None, ) -> None: - r""" + """ Overview: - Init the QRDQN Model according to input arguments. + Initialize the QRDQN Model according to input arguments. Arguments: - obs_shape (:obj:`Union[int, SequenceType]`): Observation's space. - action_shape (:obj:`Union[int, SequenceType]`): Action's space. @@ -429,7 +429,7 @@ def __init__( ) def forward(self, x: torch.Tensor) -> Dict: - r""" + """ Overview: Use observation tensor to predict QRDQN's output. Parameter updates with QRDQN's MLPs forward setup. @@ -439,7 +439,6 @@ def forward(self, x: torch.Tensor) -> Dict: Returns: - outputs (:obj:`Dict`): Run with encoder and head. Return the result prediction dictionary. - ReturnsKeys: - logit (:obj:`torch.Tensor`): Logit tensor with same size as input ``x``. - q (:obj:`torch.Tensor`): Q valye tensor tensor of size ``(B, N, num_quantiles)`` @@ -448,7 +447,6 @@ def forward(self, x: torch.Tensor) -> Dict: - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is head_hidden_size. - logit (:obj:`torch.FloatTensor`): :math:`(B, M)`, where M is action_shape. - tau (:obj:`torch.Tensor`): :math:`(B, M, 1)` - Examples: >>> model = QRDQN(64, 64) >>> inputs = torch.randn(4, 64) @@ -466,6 +464,14 @@ def forward(self, x: torch.Tensor) -> Dict: @MODEL_REGISTRY.register('iqn') class IQN(nn.Module): + """ + Overview: + The neural network structure and computation graph of IQN, which combines distributional RL and DQN. \ + You can refer to paper Implicit Quantile Networks for Distributional Reinforcement Learning \ + https://arxiv.org/pdf/1806.06923.pdf for more details. + Interfaces: + ``__init__``, ``forward`` + """ def __init__( self, @@ -479,9 +485,9 @@ def __init__( activation: Optional[nn.Module] = nn.ReLU(), norm_type: Optional[str] = None ) -> None: - r""" + """ Overview: - Init the IQN Model according to input arguments. + Initialize the IQN Model according to input arguments. Arguments: - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape. - action_shape (:obj:`Union[int, SequenceType]`): Action space shape. @@ -536,7 +542,7 @@ def __init__( ) def forward(self, x: torch.Tensor) -> Dict: - r""" + """ Overview: Use encoded embedding tensor to predict IQN's output. Parameter updates with IQN's MLPs forward setup. @@ -546,7 +552,6 @@ def forward(self, x: torch.Tensor) -> Dict: Returns: - outputs (:obj:`Dict`): Run with encoder and head. Return the result prediction dictionary. - ReturnsKeys: - logit (:obj:`torch.Tensor`): Logit tensor with same size as input ``x``. - q (:obj:`torch.Tensor`): Q valye tensor tensor of size ``(num_quantiles, N, B)`` @@ -573,6 +578,14 @@ def forward(self, x: torch.Tensor) -> Dict: @MODEL_REGISTRY.register('fqf') class FQF(nn.Module): + """ + Overview: + The neural network structure and computation graph of FQF, which combines distributional RL and DQN. \ + You can refer to paper Fully Parameterized Quantile Function for Distributional Reinforcement Learning \ + https://arxiv.org/pdf/1911.02140.pdf for more details. 
+ Interfaces:
+ ``__init__``, ``forward``
+ """

def __init__(
self,
@@ -586,9 +599,9 @@ def __init__(
activation: Optional[nn.Module] = nn.ReLU(),
norm_type: Optional[str] = None
) -> None:
- r"""
+ """
Overview:
- Init the FQF Model according to input arguments.
+ Initialize the FQF Model according to input arguments.
Arguments:
- obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape.
- action_shape (:obj:`Union[int, SequenceType]`): Action space shape.
@@ -643,7 +656,7 @@ def __init__(
)

def forward(self, x: torch.Tensor) -> Dict:
- r"""
+ """
Overview:
Use encoded embedding tensor to predict FQF's output.
Parameter updates with FQF's MLPs forward setup.
@@ -685,7 +698,11 @@ def forward(self, x: torch.Tensor) -> Dict:
class RainbowDQN(nn.Module):
"""
Overview:
- RainbowDQN network (C51 + Dueling + Noisy Block)
+ The neural network structure and computation graph of RainbowDQN, which combines C51 distributional RL, \
+ dueling architecture and noisy networks with DQN. You can refer to paper Rainbow: Combining Improvements \
+ in Deep Reinforcement Learning https://arxiv.org/pdf/1710.02298.pdf for more details.
+ Interfaces:
+ ``__init__``, ``forward``

.. note::
RainbowDQN contains dueling architecture by default.
@@ -762,7 +779,7 @@ def __init__(
)

def forward(self, x: torch.Tensor) -> Dict:
- r"""
+ """
Overview:
Use observation tensor to predict Rainbow output.
Parameter updates with Rainbow's MLPs forward setup.
@@ -772,7 +789,6 @@ def forward(self, x: torch.Tensor) -> Dict:
Returns:
- outputs (:obj:`Dict`): Run ``MLP`` with ``RainbowHead`` setups and return the result prediction dictionary.
-
ReturnsKeys:
- logit (:obj:`torch.Tensor`): Logit tensor with same size as input ``x``.
- distribution (:obj:`torch.Tensor`): Distribution tensor of size ``(B, N, n_atom)``
@@ -780,7 +796,6 @@ def forward(self, x: torch.Tensor) -> Dict:
- x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is head_hidden_size.
- logit (:obj:`torch.FloatTensor`): :math:`(B, M)`, where M is action_shape.
- distribution(:obj:`torch.FloatTensor`): :math:`(B, M, P)`, where P is n_atom.
-
Examples:
>>> model = RainbowDQN(64, 64) # arguments: 'obs_shape' and 'action_shape'
>>> inputs = torch.randn(4, 64)
@@ -796,7 +811,7 @@ def forward(self, x: torch.Tensor) -> Dict:

def parallel_wrapper(forward_fn: Callable) -> Callable:
- r"""
+ """
Overview:
Process timestep T and batch_size B at the same time, in other words, treat different timestep data as different
trajectories in a batch.
@@ -941,7 +956,6 @@ def forward(self, inputs: Dict, inference: bool = False, saved_state_timesteps:
Shapes:
- obs (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``
- logit (:obj:`torch.Tensor`): :math:`(B, M)`, where B is batch size and M is ``action_shape``
-
Examples:
>>> # Init input's Keys:
>>> prev_state = [[torch.randn(1, 1, 64) for __ in range(2)] for _ in range(4)] # B=4
diff --git a/ding/model/template/qac_dist.py b/ding/model/template/qac_dist.py
index e2ce65f34c..d9390cb06e 100644
--- a/ding/model/template/qac_dist.py
+++ b/ding/model/template/qac_dist.py
@@ -8,7 +8,7 @@
@MODEL_REGISTRY.register('qac_dist')
class QACDIST(nn.Module):
- r"""
+ """
Overview:
The QAC model with distributional Q-value.
Interfaces:
@@ -32,7 +32,7 @@ def __init__(
v_max: Optional[float] = 10,
n_atom: Optional[int] = 51,
) -> None:
- r"""
+ """
Overview:
Init the QAC Distributional Model according to arguments.
Arguments: @@ -102,7 +102,7 @@ def __init__( ) def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict: - r""" + """ Overview: Use observation and action tensor to predict output. Parameter updates with QACDIST's MLPs forward setup. @@ -166,7 +166,7 @@ def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict: return getattr(self, mode)(inputs) def compute_actor(self, inputs: torch.Tensor) -> Dict: - r""" + """ Overview: Use encoded embedding tensor to predict output. Execute parameter updates with ``'compute_actor'`` mode @@ -210,7 +210,7 @@ def compute_actor(self, inputs: torch.Tensor) -> Dict: return {'logit': [x['mu'], x['sigma']]} def compute_critic(self, inputs: Dict) -> Dict: - r""" + """ Overview: Execute parameter updates with ``'compute_critic'`` mode Use encoded embedding tensor to predict output. diff --git a/ding/model/template/vae.py b/ding/model/template/vae.py index c8d7546ddc..9839f0e905 100644 --- a/ding/model/template/vae.py +++ b/ding/model/template/vae.py @@ -83,7 +83,7 @@ def encode(self, input: Dict[str, Tensor]) -> Dict[str, Any]: return {'mu': mu, 'log_var': log_var, 'obs_encoding': obs_encoding} def decode(self, z: Tensor, obs_encoding: Tensor) -> Dict[str, Any]: - r""" + """ Overview: Maps the given latent action and obs_encoding onto the original action space. Arguments: @@ -108,7 +108,7 @@ def decode(self, z: Tensor, obs_encoding: Tensor) -> Dict[str, Any]: return {'reconstruction_action': reconstruction_action, 'predition_residual': predition_residual} def decode_with_obs(self, z: Tensor, obs: Tensor) -> Dict[str, Any]: - r""" + """ Overview: Maps the given latent action and obs onto the original action space. Using the method self.encode_obs_head(obs) to get the obs_encoding. @@ -136,7 +136,7 @@ def decode_with_obs(self, z: Tensor, obs: Tensor) -> Dict[str, Any]: return {'reconstruction_action': reconstruction_action, 'predition_residual': predition_residual} def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: - r""" + """ Overview: Reparameterization trick to sample from N(mu, var) from N(0,1). 
Arguments: diff --git a/dizoo/d4rl/config/hopper_medium_expert_pd_config.py b/dizoo/d4rl/config/hopper_medium_expert_pd_config.py index 6205018751..2edb0a7c5a 100755 --- a/dizoo/d4rl/config/hopper_medium_expert_pd_config.py +++ b/dizoo/d4rl/config/hopper_medium_expert_pd_config.py @@ -29,7 +29,7 @@ transition_dim=14, dim=32, dim_mults=[1, 2, 4, 8], - returns_condition=False, + # returns_condition=False, kernel_size=5, attention=False, ), @@ -74,7 +74,8 @@ plan_batch_size=64, learner=dict(hook=dict(save_ckpt_after_iter=1000000000, )), ), - collect=dict(data_type='diffuser_traj', ), + # collect=dict(data_type='diffuser_traj', ), + collect=dict(data_type='diffuser_traj', data_path='/mnt/afs/niuyazhe/code/dataset/d4rl/hopper_medium_expert-v2.hdf5', ), eval=dict( evaluator=dict(eval_freq=500, ), test_ret=0.9, diff --git a/dizoo/d4rl/entry/d4rl_pd_main.py b/dizoo/d4rl/entry/d4rl_pd_main.py index 73e08288ed..71158c9cf4 100755 --- a/dizoo/d4rl/entry/d4rl_pd_main.py +++ b/dizoo/d4rl/entry/d4rl_pd_main.py @@ -1,6 +1,10 @@ from ding.entry import serial_pipeline_offline from ding.config import read_config from pathlib import Path +import os + +os.environ["LD_LIBRARY_PATH"] = "/mnt/afs/niuyazhe/code/.mujoco/mujoco210/bin:/usr/local/nvidia/lib64" +os.environ["MUJOCO_PY_MUJOCO_PATH"] = "/mnt/afs/niuyazhe/code/.mujoco/mujoco210" def train(args): @@ -16,6 +20,8 @@ def train(args): parser = argparse.ArgumentParser() parser.add_argument('--seed', '-s', type=int, default=10) - parser.add_argument('--config', '-c', type=str, default='hopper_expert_cql_config.py') + # parser.add_argument('--config', '-c', type=str, default='hopper_expert_cql_config.py') + parser.add_argument('--config', '-c', type=str, default='hopper_medium_expert_pd_config.py') + args = parser.parse_args() train(args) \ No newline at end of file
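
Editor's note on the two dizoo changes above: they pin the MuJoCo runtime and the D4RL dataset to absolute paths on a single machine. A more portable pattern is to resolve those locations from the environment before the config is built. The sketch below is illustrative only; the variable names MUJOCO_ROOT and D4RL_DATASET_PATH and the fallback paths are assumptions, not part of DI-engine or this patch.

    # Minimal sketch (assumed variable names, not part of this patch): resolve
    # machine-specific paths from the environment instead of hard-coding them.
    import os

    # MuJoCo install location, falling back to the conventional ~/.mujoco layout.
    mujoco_root = os.environ.get("MUJOCO_ROOT", os.path.expanduser("~/.mujoco/mujoco210"))
    os.environ.setdefault("MUJOCO_PY_MUJOCO_PATH", mujoco_root)
    os.environ["LD_LIBRARY_PATH"] = os.pathsep.join(
        p for p in (os.path.join(mujoco_root, "bin"), os.environ.get("LD_LIBRARY_PATH")) if p
    )

    # Offline dataset location; the config key matches the one set above in
    # hopper_medium_expert_pd_config.py (policy.collect.data_path).
    data_path = os.environ.get(
        "D4RL_DATASET_PATH", os.path.expanduser("~/d4rl/hopper_medium_expert-v2.hdf5")
    )
    # main_config.policy.collect.data_path = data_path  # inject before launching the offline pipeline

Because LD_LIBRARY_PATH is consulted when shared libraries are loaded, setting it from inside the Python process only affects libraries loaded afterwards; exporting it in the shell before running d4rl_pd_main.py is generally the more robust choice.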