Skip to content

Commit

Permalink
doc(zjow): polish ding model common/template note (#741)
Browse files Browse the repository at this point in the history
* polish ding.model.template

* polish code
  • Loading branch information
zjowowen authored Nov 1, 2023
1 parent 439680a commit c5a4be3
Show file tree
Hide file tree
Showing 17 changed files with 1,447 additions and 400 deletions.
108 changes: 96 additions & 12 deletions ding/model/common/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def prod(iterable):
class ConvEncoder(nn.Module):
"""
Overview:
The ``Convolution Encoder`` used to encode raw 2-dim image observations (e.g. Atari/Procgen).
The Convolution Encoder is used to encode 2-dim image observations.
Interfaces:
``__init__``, ``forward``.
"""
Expand Down Expand Up @@ -106,6 +106,18 @@ def _get_flatten_size(self) -> int:
- outputs (:obj:`torch.Tensor`): Size ``int`` Tensor representing the number of ``in-features``.
Shapes:
- outputs: :math:`(1,)`.
Examples:
>>> conv = ConvEncoder(
>>> obs_shape=(4, 84, 84),
>>> hidden_size_list=[32, 64, 64, 128],
>>> activation=nn.ReLU(),
>>> kernel_size=[8, 4, 3],
>>> stride=[4, 2, 1],
>>> padding=None,
>>> layer_norm=False,
>>> norm_type=None
>>> )
>>> flatten_size = conv._get_flatten_size()
"""
test_data = torch.randn(1, *self.obs_shape)
with torch.no_grad():
Expand All @@ -123,6 +135,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
Shapes:
- x : :math:`(B, C, H, W)`, where ``B`` is batch size, ``C`` is channel, ``H`` is height, ``W`` is width.
- outputs: :math:`(B, N)`, where ``N = hidden_size_list[-1]`` .
Examples:
>>> conv = ConvEncoder(
>>> obs_shape=(4, 84, 84),
>>> hidden_size_list=[32, 64, 64, 128],
>>> activation=nn.ReLU(),
>>> kernel_size=[8, 4, 3],
>>> stride=[4, 2, 1],
>>> padding=None,
>>> layer_norm=False,
>>> norm_type=None
>>> )
>>> x = torch.randn(1, 4, 84, 84)
>>> output = conv(x)
"""
x = self.main(x)
x = self.mid(x)
Expand All @@ -132,7 +157,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
class FCEncoder(nn.Module):
"""
Overview:
The ``FCEncoder`` used in models to encode raw 1-dim observations (e.g. MuJoCo).
        The fully connected encoder is used to encode 1-dim input variables.
Interfaces:
``__init__``, ``forward``.
"""
Expand All @@ -148,7 +173,7 @@ def __init__(
) -> None:
"""
Overview:
Init the FC Encoder according to arguments.
Initialize the FC Encoder according to arguments.
Arguments:
- obs_shape (:obj:`int`): Observation shape.
- hidden_size_list (:obj:`SequenceType`): Sequence of ``hidden_size`` of subsequent FC layers.
Expand Down Expand Up @@ -194,6 +219,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
Shapes:
- x : :math:`(B, M)`, where ``M = obs_shape``.
- outputs: :math:`(B, N)`, where ``N = hidden_size_list[-1]``.
Examples:
>>> fc = FCEncoder(
>>> obs_shape=4,
>>> hidden_size_list=[32, 64, 64, 128],
>>> activation=nn.ReLU(),
>>> norm_type=None,
>>> dropout=None
>>> )
>>> x = torch.randn(1, 4)
>>> output = fc(x)
"""
x = self.act(self.init(x))
x = self.main(x)
Expand All @@ -211,13 +246,18 @@ def __init__(self, obs_shape: Dict[str, Union[int, List[int]]]) -> None:
class IMPALACnnResidualBlock(nn.Module):
"""
Overview:
Residual basic block (without batchnorm) in IMPALA CNN encoder, which preserves the channel number and shape.
        This residual block is the basic building block used in the IMPALA CNN encoder,
        which preserves the channel number and spatial shape of its input.
IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures
https://arxiv.org/pdf/1802.01561.pdf
Interfaces:
``__init__``, ``forward``.
"""

def __init__(self, in_channnel: int, scale: float = 1, batch_norm: bool = False):
"""
Overview:
Init the IMPALA CNN residual block according to arguments.
Initialize the IMPALA CNN residual block according to arguments.
Arguments:
- in_channnel (:obj:`int`): Channel number of input features.
- scale (:obj:`float`): Scale of module, defaults to 1.
Expand All @@ -234,9 +274,16 @@ def __init__(self, in_channnel: int, scale: float = 1, batch_norm: bool = False)
self.bn1 = nn.BatchNorm2d(self.in_channnel)

def residual(self, x: torch.Tensor) -> torch.Tensor:
# inplace should be False for the first relu, so that it does not change the input,
# which will be used for skip connection.
# getattr is for backwards compatibility with loaded models
"""
Overview:
            Return the output tensor of the residual block, keeping the shape and channel number unchanged.
            The ``inplace`` flag of the first ReLU activation must be False,
            so that it does not modify the original input tensor of the residual block.
Arguments:
- x (:obj:`torch.Tensor`): Input tensor.
Returns:
- output (:obj:`torch.Tensor`): Output tensor.
"""
if self.batch_norm:
x = self.bn0(x)
x = F.relu(x, inplace=False)
Expand All @@ -255,20 +302,30 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
- x (:obj:`torch.Tensor`): Input tensor.
Returns:
- output (:obj:`torch.Tensor`): Output tensor.
Examples:
>>> block = IMPALACnnResidualBlock(16)
>>> x = torch.randn(1, 16, 84, 84)
>>> output = block(x)
"""
return x + self.residual(x)


class IMPALACnnDownStack(nn.Module):
"""
Overview:
Downsampling stack from IMPALA CNN encoder, which reduces the spatial size by 2 with maxpooling.
        Downsampling stack of the CNN encoder used in the IMPALA algorithm.
        Every IMPALACnnDownStack consists of n IMPALACnnResidualBlock modules,
        and reduces the spatial size by 2 with maxpooling.
IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures
https://arxiv.org/pdf/1802.01561.pdf
Interfaces:
``__init__``, ``forward``.
"""

def __init__(self, in_channnel, nblock, out_channel, scale=1, pool=True, **kwargs):
"""
Overview:
Init every impala cnn block of the Impala Cnn Encoder.
Initialize every impala cnn block of the Impala Cnn Encoder.
Arguments:
- in_channnel (:obj:`int`): Channel number of input features.
- nblock (:obj:`int`): Residual Block number in each block.
Expand All @@ -293,6 +350,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
- x (:obj:`torch.Tensor`): Input tensor.
Returns:
- output (:obj:`torch.Tensor`): Output tensor.
Examples:
>>> stack = IMPALACnnDownStack(16, 2, 32)
>>> x = torch.randn(1, 16, 84, 84)
>>> output = stack(x)
"""
x = self.firstconv(x)
if self.pool:
Expand All @@ -305,6 +366,17 @@ def output_shape(self, inshape: tuple) -> tuple:
"""
Overview:
Calculate the output shape of the downsampling stack according to input shape and related arguments.
Arguments:
- inshape (:obj:`tuple`): Input shape.
Returns:
- output_shape (:obj:`tuple`): Output shape.
Shapes:
- inshape (:obj:`tuple`): :math:`(C, H, W)`, where C is channel number, H is height and W is width.
- output_shape (:obj:`tuple`): :math:`(C, H, W)`, where C is channel number, H is height and W is width.
Examples:
>>> stack = IMPALACnnDownStack(16, 2, 32)
>>> inshape = (16, 84, 84)
>>> output_shape = stack.output_shape(inshape)
"""
c, h, w = inshape
assert c == self.in_channnel
Expand All @@ -317,7 +389,7 @@ def output_shape(self, inshape: tuple) -> tuple:
class IMPALAConvEncoder(nn.Module):
"""
Overview:
IMPALA CNN encoder, which is used in IMPALA algorithm. \
IMPALA CNN encoder, which is used in IMPALA algorithm.
IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures, \
https://arxiv.org/pdf/1802.01561.pdf,
Interface:
Expand All @@ -337,7 +409,7 @@ def __init__(
) -> None:
"""
Overview:
Init the IMPALA CNN encoder according to arguments.
Initialize the IMPALA CNN encoder according to arguments.
Arguments:
- obs_shape (:obj:`SequenceType`): 2D image observation shape.
- channels (:obj:`SequenceType`): The channel number of a series of impala cnn blocks. \
Expand All @@ -348,6 +420,7 @@ def __init__(
observation, such as dividing 255.0 for the raw image observation.
- nblock (:obj:`int`): The number of Residual Block in each block.
- final_relu (:obj:`bool`): Whether to use ReLU activation in the final output of encoder.
- kwargs (:obj:`Dict[str, Any]`): Other arguments for ``IMPALACnnDownStack``.
"""
super().__init__()
self.scale_ob = scale_ob
Expand Down Expand Up @@ -375,6 +448,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
- x (:obj:`torch.Tensor`): :math:`(B, C, H, W)`, where B is batch size, C is channel number, H is height \
and W is width.
- output (:obj:`torch.Tensor`): :math:`(B, outsize)`, where B is batch size.
Examples:
>>> encoder = IMPALAConvEncoder(
>>> obs_shape=(4, 84, 84),
>>> channels=(16, 32, 32),
>>> outsize=256,
>>> scale_ob=255.0,
>>> nblock=2,
>>> final_relu=True,
>>> )
>>> x = torch.randn(1, 4, 84, 84)
>>> output = encoder(x)
"""
x = x / self.scale_ob
for (i, layer) in enumerate(self.stacks):
Expand Down
Loading

0 comments on commit c5a4be3

Please sign in to comment.