diff --git a/ding/torch_utils/network/nn_module.py b/ding/torch_utils/network/nn_module.py
index b0d1daa5ba..600d00e2d0 100644
--- a/ding/torch_utils/network/nn_module.py
+++ b/ding/torch_utils/network/nn_module.py
@@ -314,8 +314,8 @@ def MLP(
     norm_type: str = None,
     use_dropout: bool = False,
     dropout_probability: float = 0.5,
-    output_activation: nn.Module = None,
-    output_norm_type: str = None,
+    output_activation: bool = True,
+    output_norm: bool = True,
     last_linear_layer_init_zero: bool = False
 ):
     r"""
@@ -328,15 +328,18 @@ def MLP(
         - hidden_channels (:obj:`int`): Number of channels in the hidden tensor.
         - out_channels (:obj:`int`): Number of channels in the output tensor.
         - layer_num (:obj:`int`): Number of layers.
-        - layer_fn (:obj:`Callable`): layer function.
-        - activation (:obj:`nn.Module`): the optional activation function.
-        - norm_type (:obj:`str`): type of the normalization.
-        - use_dropout (:obj:`bool`): whether to use dropout in the fully-connected block.
-        - dropout_probability (:obj:`float`): probability of an element to be zeroed in the dropout. Default: 0.5.
-        - output_activation (:obj:`nn.Module`): the optional activation function in the last layer.
-        - output_norm_type (:obj:`str`): type of the normalization in the last layer.
-        - last_linear_layer_init_zero (:obj:`bool`): zero initialization for the last linear layer (including w and b),
-            which can provide stable zero outputs in the beginning.
+        - layer_fn (:obj:`Callable`): Layer function.
+        - activation (:obj:`nn.Module`): The optional activation function.
+        - norm_type (:obj:`str`): The type of the normalization.
+        - use_dropout (:obj:`bool`): Whether to use dropout in the fully-connected block.
+        - dropout_probability (:obj:`float`): The probability of an element to be zeroed in the dropout. Default: 0.5.
+        - output_activation (:obj:`bool`): Whether to use activation in the output layer. If True,
+            we use the same activation as front layers. Default: True.
+        - output_norm (:obj:`bool`): Whether to use normalization in the output layer. If True,
+            we use the same normalization as front layers. Default: True.
+        - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last linear layer
+            (including w and b), which can provide stable zero outputs in the beginning,
+            usually used in the policy network in RL settings.
     Returns:
         - block (:obj:`nn.Sequential`): a sequential list containing the torch layers of the fully-connected block.
 
@@ -361,30 +364,31 @@ def MLP(
         if use_dropout:
             block.append(nn.Dropout(dropout_probability))
 
-    # the last layer
+    # The last layer
     in_channels = channels[-2]
     out_channels = channels[-1]
-    if output_activation is None and output_norm_type is None:
-        # the last layer use the same norm and activation as front layers
-        block.append(layer_fn(in_channels, out_channels))
+    block.append(layer_fn(in_channels, out_channels))
+    """
+    In the final layer of a neural network, whether to use normalization and activation are typically determined
+    based on user specifications. These specifications depend on the problem at hand and the desired properties of
+    the model's output.
+    """
+    if output_norm is True:
+        # The last layer uses the same norm as front layers.
         if norm_type is not None:
             block.append(build_normalization(norm_type, dim=1)(out_channels))
+    if output_activation is True:
+        # The last layer uses the same activation as front layers.
         if activation is not None:
             block.append(activation)
-        if use_dropout:
-            block.append(nn.Dropout(dropout_probability))
-    else:
-        # the last layer use the specific norm and activation
-        block.append(layer_fn(in_channels, out_channels))
-        if output_norm_type is not None:
-            block.append(build_normalization(output_norm_type, dim=1)(out_channels))
-        if output_activation is not None:
-            block.append(output_activation)
-        if use_dropout:
-            block.append(nn.Dropout(dropout_probability))
-    if last_linear_layer_init_zero:
-        block[-2].weight.data.fill_(0)
-        block[-2].bias.data.fill_(0)
+
+    if last_linear_layer_init_zero:
+        # Locate the last linear layer and initialize its weights and biases to 0.
+        for _, layer in enumerate(reversed(block)):
+            if isinstance(layer, nn.Linear):
+                nn.init.zeros_(layer.weight)
+                nn.init.zeros_(layer.bias)
+                break
 
     return sequential_pack(block)
 
diff --git a/ding/torch_utils/network/tests/test_nn_module.py b/ding/torch_utils/network/tests/test_nn_module.py
index 394aa5856d..8fdc7845ee 100644
--- a/ding/torch_utils/network/tests/test_nn_module.py
+++ b/ding/torch_utils/network/tests/test_nn_module.py
@@ -1,6 +1,8 @@
-import torch
 import pytest
-from ding.torch_utils import build_activation, build_normalization
+import torch
+from torch.testing import assert_allclose
+
+from ding.torch_utils import build_activation
 from ding.torch_utils.network.nn_module import MLP, conv1d_block, conv2d_block, fc_block, deconv2d_block, \
     ChannelShuffle, one_hot, NearestUpsample, BilinearUpsample, binary_encode, weight_init_, NaiveFlatten, \
     normed_linear, normed_conv2d
@@ -44,20 +46,48 @@ def test_weight_init(self):
             weight_init_(weight, 'xxx')
 
     def test_mlp(self):
-        input = torch.rand(batch_size, in_channels).requires_grad_(True)
-        block = MLP(
-            in_channels=in_channels,
-            hidden_channels=hidden_channels,
-            out_channels=out_channels,
-            layer_num=2,
-            activation=torch.nn.ReLU(inplace=True),
-            norm_type='BN',
-            output_activation=torch.nn.Identity(),
-            output_norm_type=None,
-            last_linear_layer_init_zero=True
-        )
-        output = self.run_model(input, block)
-        assert output.shape == (batch_size, out_channels)
+        layer_num = 3
+        input_tensor = torch.rand(batch_size, in_channels).requires_grad_(True)
+
+        for output_activation in [True, False]:
+            for output_norm in [True, False]:
+                for activation in [torch.nn.ReLU(), torch.nn.LeakyReLU(), torch.nn.Tanh(), None]:
+                    for norm_type in ["LN", "BN", None]:
+                        # Test case 1: MLP without last linear layer initialized to 0.
+                        model = MLP(
+                            in_channels,
+                            hidden_channels,
+                            out_channels,
+                            layer_num,
+                            activation=activation,
+                            norm_type=norm_type,
+                            output_activation=output_activation,
+                            output_norm=output_norm
+                        )
+                        output_tensor = self.run_model(input_tensor, model)
+                        assert output_tensor.shape == (batch_size, out_channels)
+
+                        # Test case 2: MLP with last linear layer initialized to 0.
+                        model = MLP(
+                            in_channels,
+                            hidden_channels,
+                            out_channels,
+                            layer_num,
+                            activation=activation,
+                            norm_type=norm_type,
+                            output_activation=output_activation,
+                            output_norm=output_norm,
+                            last_linear_layer_init_zero=True
+                        )
+                        output_tensor = self.run_model(input_tensor, model)
+                        assert output_tensor.shape == (batch_size, out_channels)
+                        last_linear_layer = None
+                        for layer in reversed(model):
+                            if isinstance(layer, torch.nn.Linear):
+                                last_linear_layer = layer
+                                break
+                        assert_allclose(last_linear_layer.weight, torch.zeros_like(last_linear_layer.weight))
+                        assert_allclose(last_linear_layer.bias, torch.zeros_like(last_linear_layer.bias))
 
     def test_conv1d_block(self):
         length = 2
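For reference, a minimal usage sketch of the updated interface. The function name, keyword names, and import path come from this PR; the variable name `head`, the concrete channel sizes, batch size, and the `ReLU`/`LN` choices below are illustrative assumptions, not values taken from the diff.

# Usage sketch (hypothetical sizes): a policy-style head built with the new boolean flags.
import torch

from ding.torch_utils.network.nn_module import MLP

# No normalization or activation after the final linear layer, and the last
# linear layer is zero-initialized for stable zero outputs at the start of training.
head = MLP(
    in_channels=4,
    hidden_channels=64,
    out_channels=6,
    layer_num=3,
    activation=torch.nn.ReLU(),
    norm_type='LN',
    output_activation=False,
    output_norm=False,
    last_linear_layer_init_zero=True,
)

x = torch.rand(8, 4)
y = head(x)
assert y.shape == (8, 6)
# The final module is the zero-initialized nn.Linear (nothing follows it because
# output_norm and output_activation are False), so the untrained head maps any input to zeros.
assert torch.allclose(y, torch.zeros_like(y))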