From 0c455db55dd015d5c2c258544fc351210cf643a5 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Fri, 17 Jun 2022 15:20:05 +0800 Subject: [PATCH 1/3] comment some lines, random combine from 1/3 layers, on linear layers in combiner --- .../pruned_transducer_stateless5/conformer.py | 41 +++++++++++-------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py index 6f7231f4bb..e2ec2e1cf1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py @@ -87,10 +87,17 @@ def __init__( layer_dropout, cnn_module_kernel, ) + # aux_layers from 1/3 self.encoder = ConformerEncoder( encoder_layer, num_encoder_layers, - aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)), + aux_layers=list( + range( + num_encoder_layers // 3, + num_encoder_layers - 1, + aux_layer_period, + ) + ), ) def forward( @@ -296,10 +303,10 @@ def __init__( assert num_layers - 1 not in aux_layers self.aux_layers = set(aux_layers + [num_layers - 1]) - num_channels = encoder_layer.norm_final.num_channels + # num_channels = encoder_layer.norm_final.num_channels self.combiner = RandomCombine( num_inputs=len(self.aux_layers), - num_channels=num_channels, + # num_channels=num_channels, final_weight=0.5, pure_prob=0.333, stddev=2.0, @@ -1073,7 +1080,7 @@ class RandomCombine(nn.Module): def __init__( self, num_inputs: int, - num_channels: int, + # num_channels: int, final_weight: float = 0.5, pure_prob: float = 0.5, stddev: float = 2.0, @@ -1116,12 +1123,12 @@ def __init__( assert 0 < final_weight < 1, final_weight assert num_inputs >= 1 - self.linear = nn.ModuleList( - [ - nn.Linear(num_channels, num_channels, bias=True) - for _ in range(num_inputs - 1) - ] - ) + # self.linear = nn.ModuleList( + # [ + # nn.Linear(num_channels, num_channels, bias=True) + # for _ in range(num_inputs - 1) + # ] + # ) self.num_inputs = num_inputs self.final_weight = final_weight @@ -1135,12 +1142,13 @@ def __init__( .log() .item() ) - self._reset_parameters() - def _reset_parameters(self): - for i in range(len(self.linear)): - nn.init.eye_(self.linear[i].weight) - nn.init.constant_(self.linear[i].bias, 0.0) + # self._reset_parameters() + + # def _reset_parameters(self): + # for i in range(len(self.linear)): + # nn.init.eye_(self.linear[i].weight) + # nn.init.constant_(self.linear[i].bias, 0.0) def forward(self, inputs: List[Tensor]) -> Tensor: """Forward function. @@ -1163,7 +1171,8 @@ def forward(self, inputs: List[Tensor]) -> Tensor: mod_inputs = [] for i in range(num_inputs - 1): - mod_inputs.append(self.linear[i](inputs[i])) + # mod_inputs.append(self.linear[i](inputs[i])) + mod_inputs.append(inputs[i]) mod_inputs.append(inputs[num_inputs - 1]) ndim = inputs[0].ndim From d1362a5a2fa7b2adfbb361b052b51c299c199115 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Tue, 28 Jun 2022 20:56:53 +0800 Subject: [PATCH 2/3] delete commented lines --- .../pruned_transducer_stateless5/conformer.py | 20 ------------------- 1 file changed, 20 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py index e2ec2e1cf1..f924cb6125 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py @@ -303,10 +303,8 @@ def __init__( assert num_layers - 1 not in aux_layers self.aux_layers = set(aux_layers + [num_layers - 1]) - # num_channels = encoder_layer.norm_final.num_channels self.combiner = RandomCombine( num_inputs=len(self.aux_layers), - # num_channels=num_channels, final_weight=0.5, pure_prob=0.333, stddev=2.0, @@ -1080,7 +1078,6 @@ class RandomCombine(nn.Module): def __init__( self, num_inputs: int, - # num_channels: int, final_weight: float = 0.5, pure_prob: float = 0.5, stddev: float = 2.0, @@ -1091,8 +1088,6 @@ def __init__( The number of tensor inputs, which equals the number of layers' outputs that are fed into this module. E.g. in an 18-layer neural net if we output layers 16, 12, 18, num_inputs would be 3. - num_channels: - The number of channels on the input, e.g. 512. final_weight: The amount of weight or probability we assign to the final layer when randomly choosing layers or when choosing @@ -1123,13 +1118,6 @@ def __init__( assert 0 < final_weight < 1, final_weight assert num_inputs >= 1 - # self.linear = nn.ModuleList( - # [ - # nn.Linear(num_channels, num_channels, bias=True) - # for _ in range(num_inputs - 1) - # ] - # ) - self.num_inputs = num_inputs self.final_weight = final_weight self.pure_prob = pure_prob @@ -1143,13 +1131,6 @@ def __init__( .item() ) - # self._reset_parameters() - - # def _reset_parameters(self): - # for i in range(len(self.linear)): - # nn.init.eye_(self.linear[i].weight) - # nn.init.constant_(self.linear[i].bias, 0.0) - def forward(self, inputs: List[Tensor]) -> Tensor: """Forward function. Args: @@ -1171,7 +1152,6 @@ def forward(self, inputs: List[Tensor]) -> Tensor: mod_inputs = [] for i in range(num_inputs - 1): - # mod_inputs.append(self.linear[i](inputs[i])) mod_inputs.append(inputs[i]) mod_inputs.append(inputs[num_inputs - 1]) From 11f96e2d24ac522bf313103a83c6f8b95c236898 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Thu, 30 Jun 2022 12:11:09 +0800 Subject: [PATCH 3/3] minor change --- .../ASR/pruned_transducer_stateless5/conformer.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py index 6ff0fbe8e7..49bc6a489a 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py @@ -1149,13 +1149,9 @@ def forward(self, inputs: List[Tensor]) -> Tensor: num_channels = inputs[0].shape[-1] num_frames = inputs[0].numel() // num_channels - mod_inputs = [] - for i in range(num_inputs): - mod_inputs.append(inputs[i]) - ndim = inputs[0].ndim # stacked_inputs: (num_frames, num_channels, num_inputs) - stacked_inputs = torch.stack(mod_inputs, dim=ndim).reshape( + stacked_inputs = torch.stack(inputs, dim=ndim).reshape( (num_frames, num_channels, num_inputs) )