fix(pu): fix last_linear_layer_weight_bias_init_zero in MLP and add its unittest #650

Merged: 8 commits, Apr 25, 2023
26 changes: 14 additions & 12 deletions ding/torch_utils/network/nn_module.py
@@ -316,7 +316,7 @@ def MLP(
dropout_probability: float = 0.5,
output_activation: nn.Module = None,
output_norm_type: str = None,
last_linear_layer_init_zero: bool = False
last_linear_layer_weight_bias_init_zero: bool = False
):
r"""
Overview:
@@ -335,8 +335,8 @@ def MLP(
- dropout_probability (:obj:`float`): probability of an element to be zeroed in the dropout. Default: 0.5.
- output_activation (:obj:`nn.Module`): the optional activation function in the last layer.
- output_norm_type (:obj:`str`): type of the normalization in the last layer.
- last_linear_layer_init_zero (:obj:`bool`): zero initialization for the last linear layer (including w and b),
which can provide stable zero outputs in the beginning.
- last_linear_layer_weight_bias_init_zero (:obj:`bool`): zero initialization for the last linear layer
(including w and b), which can provide stable zero outputs at the beginning of training.
Returns:
- block (:obj:`nn.Sequential`): a sequential list containing the torch layers of the fully-connected block.

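For reference, a minimal usage sketch of the renamed flag (the sizes below are placeholders chosen for illustration, not values taken from this PR):

import torch
from ding.torch_utils.network.nn_module import MLP

# Build a small MLP whose final linear layer starts with zero weight and bias.
mlp = MLP(
    in_channels=4,
    hidden_channels=8,
    out_channels=2,
    layer_num=3,
    activation=torch.nn.ReLU(inplace=True),
    last_linear_layer_weight_bias_init_zero=True,
)
x = torch.rand(16, 4)
# The zero-initialized final nn.Linear outputs zeros and ReLU keeps them at zero,
# so the freshly built network maps any input to zeros before training.
assert torch.all(mlp(x) == 0)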
@@ -364,27 +364,29 @@ def MLP(
# the last layer
in_channels = channels[-2]
out_channels = channels[-1]
block.append(layer_fn(in_channels, out_channels))
if output_activation is None and output_norm_type is None:
# the last layer uses the same norm and activation as the preceding layers
block.append(layer_fn(in_channels, out_channels))
if norm_type is not None:
block.append(build_normalization(norm_type, dim=1)(out_channels))
if activation is not None:
block.append(activation)
if use_dropout:
block.append(nn.Dropout(dropout_probability))
else:
# the last layer uses the specified output norm and activation
block.append(layer_fn(in_channels, out_channels))
if output_norm_type is not None:
block.append(build_normalization(output_norm_type, dim=1)(out_channels))
if output_activation is not None:
block.append(output_activation)
if use_dropout:
block.append(nn.Dropout(dropout_probability))
if last_linear_layer_init_zero:
block[-2].weight.data.fill_(0)
block[-2].bias.data.fill_(0)
if use_dropout:
block.append(nn.Dropout(dropout_probability))

if last_linear_layer_weight_bias_init_zero:
# Locate the last linear layer and initialize its weights and biases to 0.
for _, layer in enumerate(reversed(block)):
if isinstance(layer, nn.Linear):
nn.init.zeros_(layer.weight)
nn.init.zeros_(layer.bias)
break

return sequential_pack(block)

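Because the zero-initialization now walks the assembled block in reverse until it hits an nn.Linear, it works regardless of which output activation, normalization, or dropout modules follow the final linear layer. A small sketch of that check, mirroring the new unit test (module sizes are arbitrary):

import torch
from ding.torch_utils.network.nn_module import MLP

block = MLP(
    in_channels=4,
    hidden_channels=8,
    out_channels=2,
    layer_num=2,
    output_activation=torch.nn.Tanh(),
    last_linear_layer_weight_bias_init_zero=True,
)
# Scan the nn.Sequential backwards, exactly as the initialization code does,
# to locate the last linear layer behind the output activation.
last_linear = next(m for m in reversed(block) if isinstance(m, torch.nn.Linear))
assert torch.all(last_linear.weight == 0)
assert torch.all(last_linear.bias == 0)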
70 changes: 54 additions & 16 deletions ding/torch_utils/network/tests/test_nn_module.py
@@ -1,6 +1,8 @@
import torch
import pytest
from ding.torch_utils import build_activation, build_normalization
import torch
from torch.testing import assert_allclose

from ding.torch_utils import build_activation
from ding.torch_utils.network.nn_module import MLP, conv1d_block, conv2d_block, fc_block, deconv2d_block, \
ChannelShuffle, one_hot, NearestUpsample, BilinearUpsample, binary_encode, weight_init_, NaiveFlatten, \
normed_linear, normed_conv2d
@@ -44,20 +46,56 @@ def test_weight_init(self):
weight_init_(weight, 'xxx')

def test_mlp(self):
input = torch.rand(batch_size, in_channels).requires_grad_(True)
block = MLP(
in_channels=in_channels,
hidden_channels=hidden_channels,
out_channels=out_channels,
layer_num=2,
activation=torch.nn.ReLU(inplace=True),
norm_type='BN',
output_activation=torch.nn.Identity(),
output_norm_type=None,
last_linear_layer_init_zero=True
)
output = self.run_model(input, block)
assert output.shape == (batch_size, out_channels)
layer_num = 3
input_tensor = torch.rand(batch_size, in_channels).requires_grad_(True)

# Test case 1: Simple MLP without dropout, normalization, or output activation
model = MLP(in_channels, hidden_channels, out_channels, layer_num)
output_tensor = self.run_model(input_tensor, model)
assert output_tensor.shape == (batch_size, out_channels)

# Test case 2: MLP with dropout and normalization
for norm_type in ["LN", "BN", None]:
model = MLP(
in_channels,
hidden_channels,
out_channels,
layer_num,
use_dropout=True,
dropout_probability=0.5,
norm_type=norm_type
)
output_tensor = self.run_model(input_tensor, model)
assert output_tensor.shape == (batch_size, out_channels)

for act in [torch.nn.LeakyReLU(), torch.nn.ReLU(), torch.nn.Sigmoid(), None]:
for norm_type in ["LN", "BN", None]:
# Test case 3: MLP without last linear layer initialized to 0
model = MLP(
in_channels, hidden_channels, out_channels, layer_num, norm_type=norm_type, output_activation=act
)
output_tensor = self.run_model(input_tensor, model)
assert output_tensor.shape == (batch_size, out_channels)

# Test case 4: MLP with last linear layer initialized to 0
model = MLP(
in_channels,
hidden_channels,
out_channels,
layer_num,
norm_type=norm_type,
output_activation=act,
last_linear_layer_weight_bias_init_zero=True
)
output_tensor = self.run_model(input_tensor, model)
assert output_tensor.shape == (batch_size, out_channels)
last_linear_layer = None
for layer in reversed(model):
if isinstance(layer, torch.nn.Linear):
last_linear_layer = layer
break
assert_allclose(last_linear_layer.weight, torch.zeros_like(last_linear_layer.weight))
assert_allclose(last_linear_layer.bias, torch.zeros_like(last_linear_layer.bias))

def test_conv1d_block(self):
length = 2