From bcc219a83de15a8cbd64dfe88e9276b26c4e6076 Mon Sep 17 00:00:00 2001 From: Liuwei Wei Date: Tue, 18 Oct 2022 14:34:28 +0800 Subject: [PATCH 1/5] add se_layer for blockformer --- .../aishell/s0/conf/train_blockformer.yaml | 82 +++++++++++++++++++ wenet/transformer/decoder.py | 81 +++++++++++++++--- wenet/transformer/decoder_layer.py | 5 +- wenet/transformer/encoder.py | 24 +++++- wenet/transformer/se_layer.py | 37 +++++++++ 5 files changed, 212 insertions(+), 17 deletions(-) create mode 100644 examples/aishell/s0/conf/train_blockformer.yaml create mode 100644 wenet/transformer/se_layer.py diff --git a/examples/aishell/s0/conf/train_blockformer.yaml b/examples/aishell/s0/conf/train_blockformer.yaml new file mode 100644 index 000000000..84df9357e --- /dev/null +++ b/examples/aishell/s0/conf/train_blockformer.yaml @@ -0,0 +1,82 @@ +# network architecture +# encoder related +encoder: conformer +encoder_conf: + output_size: 256 # dimension of attention + attention_heads: 4 + linear_units: 2048 # the number of units of position-wise feed forward + num_blocks: 12 # the number of encoder blocks + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.0 + input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 + normalize_before: true + cnn_module_kernel: 15 + use_cnn_module: True + activation_type: 'swish' + pos_enc_layer_type: 'rel_pos' + selfattention_layer_type: 'rel_selfattn' + use_se_module: true + se_module_channel: 12 # the same number with encoder blocks + +# decoder related +decoder: transformer +decoder_conf: + attention_heads: 4 + linear_units: 2048 + num_blocks: 6 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + input_layer: 'rel_embed' + self_attention_dropout_rate: 0.0 + src_attention_dropout_rate: 0.0 + use_se_module: true + se_module_channel: 6 # the same number with decoder blocks + +# hybrid CTC/attention +model_conf: + ctc_weight: 0.3 + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false + +dataset_conf: + filter_conf: + max_length: 40960 + min_length: 0 + token_max_length: 200 + token_min_length: 1 + resample_conf: + resample_rate: 16000 + speed_perturb: true + fbank_conf: + num_mel_bins: 80 + frame_shift: 10 + frame_length: 25 + dither: 0.1 + spec_aug: true + spec_aug_conf: + num_t_mask: 2 + num_f_mask: 2 + max_t: 50 + max_f: 10 + shuffle: true + shuffle_conf: + shuffle_size: 1500 + sort: false + sort_conf: + sort_size: 500 # sort_size should be less than shuffle_size + batch_conf: + batch_type: 'static' # static or dynamic + batch_size: 16 + +grad_clip: 5 +accum_grad: 4 +max_epoch: 360 +log_interval: 100 + +optim: adam +optim_conf: + lr: 0.002 +scheduler: warmuplr # pytorch v1.1.0+ required +scheduler_conf: + warmup_steps: 50000 diff --git a/wenet/transformer/decoder.py b/wenet/transformer/decoder.py index c31853d9e..c77242cc6 100644 --- a/wenet/transformer/decoder.py +++ b/wenet/transformer/decoder.py @@ -20,9 +20,12 @@ from typeguard import check_argument_types from wenet.transformer.attention import MultiHeadedAttention +from wenet.transformer.attention import RelPositionMultiHeadedAttention from wenet.transformer.decoder_layer import DecoderLayer from wenet.transformer.embedding import PositionalEncoding +from wenet.transformer.embedding import RelPositionalEncoding from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward +from wenet.transformer.se_layer import SELayer from wenet.utils.mask import (subsequent_mask, make_pad_mask) @@ -61,6 +64,8 @@ def __init__( 
use_output_layer: bool = True, normalize_before: bool = True, concat_after: bool = False, + use_se_module: bool = False, + se_module_channel: int = 0 ): assert check_argument_types() super().__init__() @@ -71,10 +76,16 @@ def __init__( torch.nn.Embedding(vocab_size, attention_dim), PositionalEncoding(attention_dim, positional_dropout_rate), ) + elif input_layer == "rel_embed": + self.embed = torch.nn.Sequential( + torch.nn.Embedding(vocab_size, attention_dim), + RelPositionalEncoding(attention_dim, positional_dropout_rate), + ) else: raise ValueError(f"only 'embed' is supported: {input_layer}") self.normalize_before = normalize_before + self.input_layer = input_layer self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-5) self.use_output_layer = use_output_layer self.output_layer = torch.nn.Linear(attention_dim, vocab_size) @@ -82,8 +93,11 @@ def __init__( self.decoders = torch.nn.ModuleList([ DecoderLayer( attention_dim, - MultiHeadedAttention(attention_heads, attention_dim, - self_attention_dropout_rate), + RelPositionMultiHeadedAttention(attention_heads, attention_dim, + self_attention_dropout_rate) \ + if input_layer == "rel_embed" else \ + MultiHeadedAttention(attention_heads, attention_dim, + self_attention_dropout_rate), MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate), PositionwiseFeedForward(attention_dim, linear_units, @@ -93,6 +107,8 @@ def __init__( concat_after, ) for _ in range(self.num_blocks) ]) + self.use_se_module = use_se_module + self.se_class = SELayer(se_module_channel) def forward( self, @@ -130,10 +146,40 @@ def forward( device=tgt_mask.device).unsqueeze(0) # tgt_mask: (B, L, L) tgt_mask = tgt_mask & m - x, _ = self.embed(tgt) - for layer in self.decoders: - x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory, - memory_mask) + x, pos_emb = self.embed(tgt) + if self.use_se_module: + x_list = [] + for layer in self.decoders: + if self.input_layer == "rel_embed": + x, tgt_mask, memory, memory_mask = layer(x, + tgt_mask, + memory, + memory_mask, + pos_emb) + else: + x, tgt_mask, memory, memory_mask = layer(x, + tgt_mask, + memory, + memory_mask, + torch.empty(0)) + x_list.append(x) + x_list = torch.stack(x_list).transpose(0, 1) + x_se_output = self.se_class(x_list) + x = torch.sum(x_se_output, dim=1) + else: + for layer in self.decoders: + if self.input_layer == "rel_embed": + x, tgt_mask, memory, memory_mask = layer(x, + tgt_mask, + memory, + memory_mask, + pos_emb) + else: + x, tgt_mask, memory, memory_mask = layer(x, + tgt_mask, + memory, + memory_mask, + torch.empty(0)) if self.normalize_before: x = self.after_norm(x) if self.use_output_layer: @@ -163,18 +209,27 @@ def forward_one_step( y, cache: NN output value and cache per `self.decoders`. 
y.shape` is (batch, maxlen_out, token) """ - x, _ = self.embed(tgt) + x, pos_emb = self.embed(tgt) new_cache = [] for i, decoder in enumerate(self.decoders): if cache is None: c = None else: c = cache[i] - x, tgt_mask, memory, memory_mask = decoder(x, - tgt_mask, - memory, - memory_mask, - cache=c) + if self.input_layer == "rel_embed": + x, tgt_mask, memory, memory_mask = decoder(x, + tgt_mask, + memory, + memory_mask, + pos_emb, + cache=c) + else: + x, tgt_mask, memory, memory_mask = decoder(x, + tgt_mask, + memory, + memory_mask, + torch.empty(0), + cache=c) new_cache.append(x) if self.normalize_before: y = self.after_norm(x[:, -1]) @@ -222,6 +277,8 @@ def __init__( use_output_layer: bool = True, normalize_before: bool = True, concat_after: bool = False, + use_se_module: bool = False, + se_module_channel: int = 0 ): assert check_argument_types() diff --git a/wenet/transformer/decoder_layer.py b/wenet/transformer/decoder_layer.py index 6b52aa6ab..15e8afd47 100644 --- a/wenet/transformer/decoder_layer.py +++ b/wenet/transformer/decoder_layer.py @@ -75,6 +75,7 @@ def forward( tgt_mask: torch.Tensor, memory: torch.Tensor, memory_mask: torch.Tensor, + pos_emb: torch.Tensor, cache: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Compute decoded features. @@ -117,11 +118,11 @@ def forward( if self.concat_after: tgt_concat = torch.cat( - (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0]), dim=-1) + (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask, pos_emb)[0]), dim=-1) x = residual + self.concat_linear1(tgt_concat) else: x = residual + self.dropout( - self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0]) + self.self_attn(tgt_q, tgt, tgt, tgt_q_mask, pos_emb)[0]) if not self.normalize_before: x = self.norm1(x) diff --git a/wenet/transformer/encoder.py b/wenet/transformer/encoder.py index bb2ec6582..2c1d26703 100644 --- a/wenet/transformer/encoder.py +++ b/wenet/transformer/encoder.py @@ -33,6 +33,7 @@ from wenet.transformer.subsampling import Conv2dSubsampling6 from wenet.transformer.subsampling import Conv2dSubsampling8 from wenet.transformer.subsampling import LinearNoSubsampling +from wenet.transformer.se_layer import SELayer from wenet.utils.common import get_activation from wenet.utils.mask import make_pad_mask from wenet.utils.mask import add_optional_chunk_mask @@ -57,6 +58,8 @@ def __init__( use_dynamic_chunk: bool = False, global_cmvn: torch.nn.Module = None, use_dynamic_left_chunk: bool = False, + use_se_module: bool = False, + se_module_channel: int = 0 ): """ Args: @@ -127,6 +130,8 @@ def __init__( self.static_chunk_size = static_chunk_size self.use_dynamic_chunk = use_dynamic_chunk self.use_dynamic_left_chunk = use_dynamic_left_chunk + self.use_se_module = use_se_module + self.se_class = SELayer(se_module_channel) def output_size(self) -> int: return self._output_size @@ -169,8 +174,18 @@ def forward( decoding_chunk_size, self.static_chunk_size, num_decoding_left_chunks) - for layer in self.encoders: - xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) + if self.use_se_module: + xs_list = [] + for layer in self.encoders: + xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) + xs_list.append(xs) + xs_list = torch.stack(xs_list).transpose(0, 1) + xs_se_output = self.se_class(xs_list) + xs = torch.sum(xs_se_output, dim=1) + else: + for layer in self.encoders: + xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) + if self.normalize_before: xs = self.after_norm(xs) # Here we assume the mask 
is not changed in encoder layers, so just @@ -397,6 +412,8 @@ def __init__( cnn_module_kernel: int = 15, causal: bool = False, cnn_module_norm: str = "batch_norm", + use_se_module: bool = False, + se_module_channel: int = 0 ): """Construct ConformerEncoder @@ -420,7 +437,8 @@ def __init__( positional_dropout_rate, attention_dropout_rate, input_layer, pos_enc_layer_type, normalize_before, concat_after, static_chunk_size, use_dynamic_chunk, - global_cmvn, use_dynamic_left_chunk) + global_cmvn, use_dynamic_left_chunk, use_se_module, + se_module_channel) activation = get_activation(activation_type) # self-attention module definition diff --git a/wenet/transformer/se_layer.py b/wenet/transformer/se_layer.py new file mode 100644 index 000000000..6ef971d30 --- /dev/null +++ b/wenet/transformer/se_layer.py @@ -0,0 +1,37 @@ +# Copyright (c) 2022 Mininglamp Com (Liuwei Wei, Xiaoming Ren) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) + + +"""Squeeze-and-Excitation layer definition.""" + +import torch + + +class SELayer(torch.nn.Module): + def __init__(self, channel: int, reduction: int = 1): + super().__init__() + self.avg_pool = torch.nn.AdaptiveAvgPool2d(1) + self.fc = torch.nn.Sequential( + torch.nn.Linear(channel, channel // reduction, bias=False), + torch.nn.ReLU(inplace=True), + torch.nn.Linear(channel // reduction, channel, bias=False), + torch.nn.Sigmoid() + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y.expand_as(x)
From f85d5ef9b65ef274c2886b6b09f0b7586fbfa051 Mon Sep 17 00:00:00 2001 From: Liuwei Wei Date: Wed, 19 Oct 2022 10:56:50 +0800 Subject: [PATCH 3/5] fix formatting issues --- wenet/transformer/decoder.py | 4 ++-- wenet/transformer/decoder_layer.py | 4 +++- wenet/transformer/se_layer.py | 1 - 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/wenet/transformer/decoder.py b/wenet/transformer/decoder.py index c77242cc6..45870b261 100644 --- a/wenet/transformer/decoder.py +++ b/wenet/transformer/decoder.py @@ -94,8 +94,8 @@ def __init__( DecoderLayer( attention_dim, RelPositionMultiHeadedAttention(attention_heads, attention_dim, - self_attention_dropout_rate) \ - if input_layer == "rel_embed" else \ + self_attention_dropout_rate) + if input_layer == "rel_embed" else MultiHeadedAttention(attention_heads, attention_dim, self_attention_dropout_rate), MultiHeadedAttention(attention_heads, attention_dim, diff --git a/wenet/transformer/decoder_layer.py b/wenet/transformer/decoder_layer.py index 15e8afd47..70e390d66 100644 --- a/wenet/transformer/decoder_layer.py +++ b/wenet/transformer/decoder_layer.py @@ -118,7 +118,9 @@ def forward( if self.concat_after: tgt_concat = torch.cat( - (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask, pos_emb)[0]), dim=-1) + (tgt_q, + self.self_attn(tgt_q, tgt, tgt, tgt_q_mask, pos_emb)[0]), + dim=-1) x = residual + self.concat_linear1(tgt_concat) else: x = residual + self.dropout( diff --git a/wenet/transformer/se_layer.py b/wenet/transformer/se_layer.py index 6ef971d30..96b3dfdb1 100644 --- a/wenet/transformer/se_layer.py +++ b/wenet/transformer/se_layer.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
-# Modified from ESPnet(https://github.com/espnet/espnet) """Squeeze-and-Excitation layer definition.""" From 6e6f759e0d439a28164b12e9d91a8c6a050db8c0 Mon Sep 17 00:00:00 2001 From: Liuwei Wei Date: Wed, 19 Oct 2022 11:59:30 +0800 Subject: [PATCH 4/5] fix formatting issues --- wenet/transformer/decoder.py | 6 +++--- wenet/transformer/decoder_layer.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/wenet/transformer/decoder.py b/wenet/transformer/decoder.py index 45870b261..323eb9b5f 100644 --- a/wenet/transformer/decoder.py +++ b/wenet/transformer/decoder.py @@ -95,9 +95,9 @@ def __init__( attention_dim, RelPositionMultiHeadedAttention(attention_heads, attention_dim, self_attention_dropout_rate) - if input_layer == "rel_embed" else - MultiHeadedAttention(attention_heads, attention_dim, - self_attention_dropout_rate), + if input_layer == "rel_embed" else MultiHeadedAttention( + attention_heads, attention_dim, + self_attention_dropout_rate), MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate), PositionwiseFeedForward(attention_dim, linear_units, diff --git a/wenet/transformer/decoder_layer.py b/wenet/transformer/decoder_layer.py index 70e390d66..371223d51 100644 --- a/wenet/transformer/decoder_layer.py +++ b/wenet/transformer/decoder_layer.py @@ -118,9 +118,9 @@ def forward( if self.concat_after: tgt_concat = torch.cat( - (tgt_q, - self.self_attn(tgt_q, tgt, tgt, tgt_q_mask, pos_emb)[0]), - dim=-1) + (tgt_q, self.self_attn( + tgt_q, tgt, tgt, tgt_q_mask, pos_emb)[0]), + dim=-1) x = residual + self.concat_linear1(tgt_concat) else: x = residual + self.dropout( From 89e17d48ee1329455ec79cc31ded6910e624f183 Mon Sep 17 00:00:00 2001 From: Liuwei Wei Date: Wed, 19 Oct 2022 15:08:47 +0800 Subject: [PATCH 5/5] add aishell results --- examples/aishell/s0/README.md | 13 +++++++++++++ examples/aishell/s0/run.sh | 1 + 2 files changed, 14 insertions(+) diff --git a/examples/aishell/s0/README.md b/examples/aishell/s0/README.md index 96675cbf4..ae8b534a0 100644 --- a/examples/aishell/s0/README.md +++ b/examples/aishell/s0/README.md @@ -131,3 +131,16 @@ | ctc prefix beam search | 4.94 | 4.97 | | attention rescoring | 4.61 | 4.69 | +## blockformer Result + +* Feature info: using fbank feature, dither, cmvn, online speed perturb +* Training info: lr 0.002, batch size 16, 8 gpu, acc_grad 4, 360 epochs, dither 0.1, warm up steps 50000 +* Decoding info: ctc_weight 0.5, average_num 30 + +| decoding mode | CER | +|---------------------------|-------| +| attention decoder | 4.78 | +| ctc greedy search | 4.74 | +| ctc prefix beam search | 4.75 | +| attention rescoring | 4.41 | + diff --git a/examples/aishell/s0/run.sh b/examples/aishell/s0/run.sh index e18b77c84..fdcb6704c 100644 --- a/examples/aishell/s0/run.sh +++ b/examples/aishell/s0/run.sh @@ -44,6 +44,7 @@ train_set=train # 4. conf/train_unified_transformer.yaml: Unified dynamic chunk transformer # 5. conf/train_u2++_conformer.yaml: U2++ conformer # 6. conf/train_u2++_transformer.yaml: U2++ transformer +# 7. conf/train_blockformer.yaml: conformer with se_layer train_config=conf/train_conformer.yaml cmvn=true dir=exp/conformer
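
Note: the sketch below is not part of the patch series; it is a minimal, standalone illustration of the SE block-weighting these patches add. The patches collect the output of every encoder (or decoder) block, stack them into a (batch, num_blocks, time, dim) tensor, squeeze each block to one scalar with adaptive average pooling, turn the scalars into sigmoid gates in SELayer, and sum the re-weighted block outputs. The SELayer code is copied from wenet/transformer/se_layer.py above; the driver names (layer_outputs, stacked, fused) and the toy shapes are illustrative assumptions, not part of WeNet's API.

import torch


class SELayer(torch.nn.Module):
    """Gate over the block axis, as in wenet/transformer/se_layer.py."""

    def __init__(self, channel: int, reduction: int = 1):
        super().__init__()
        self.avg_pool = torch.nn.AdaptiveAvgPool2d(1)
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(channel, channel // reduction, bias=False),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(channel // reduction, channel, bias=False),
            torch.nn.Sigmoid())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, num_blocks, time, dim); produce one gate per block
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)


if __name__ == "__main__":
    # Illustrative shapes only: 12 encoder blocks, as in the yaml above
    num_blocks, batch, time_len, dim = 12, 2, 50, 256
    # Stand-ins for the per-block outputs collected in the encoder loop
    layer_outputs = [torch.randn(batch, time_len, dim) for _ in range(num_blocks)]
    stacked = torch.stack(layer_outputs).transpose(0, 1)  # (B, num_blocks, T, D)
    se = SELayer(num_blocks)  # se_module_channel must equal num_blocks
    fused = torch.sum(se(stacked), dim=1)  # (B, T, D): SE-weighted sum of blocks
    print(fused.shape)  # torch.Size([2, 50, 256])

With reduction left at its default of 1 the two linear layers keep the block dimension, so the gate acts purely as a learned per-block importance weight; this is also why se_module_channel in the config has to match num_blocks.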