From 82ac5f751b8f09c99eb71822da375c8ae153fafb Mon Sep 17 00:00:00 2001 From: YosukeHiguchi Date: Mon, 28 Feb 2022 09:49:14 +0900 Subject: [PATCH 1/5] add tests --- .../asr/encoder/test_conformer_encoder.py | 22 +++++++++++++++++++ .../asr/encoder/test_transformer_encoder.py | 12 +++++++++- test/espnet2/asr/test_ctc.py | 8 +++++++ 3 files changed, 41 insertions(+), 1 deletion(-) diff --git a/test/espnet2/asr/encoder/test_conformer_encoder.py b/test/espnet2/asr/encoder/test_conformer_encoder.py index 43837c2ac4e..7849ef7f61f 100644 --- a/test/espnet2/asr/encoder/test_conformer_encoder.py +++ b/test/espnet2/asr/encoder/test_conformer_encoder.py @@ -17,12 +17,17 @@ ("legacy", "legacy_rel_pos", "legacy_rel_selfattn"), ], ) +@pytest.mark.parametrize( + "interctc_layer_idx, interctc_use_conditioning", [([1], False), ([1], True)] +) def test_encoder_forward_backward( input_layer, positionwise_layer_type, rel_pos_type, pos_enc_layer_type, selfattention_layer_type, + interctc_layer_idx, + interctc_use_conditioning, ): encoder = ConformerEncoder( 20, @@ -39,6 +44,8 @@ def test_encoder_forward_backward( use_cnn_module=True, cnn_module_kernel=3, positionwise_layer_type=positionwise_layer_type, + interctc_layer_idx=interctc_layer_idx + interctc_use_conditioning=interctc_use_conditioning, ) if input_layer == "embed": x = torch.randint(0, 10, [2, 32]) @@ -82,6 +89,21 @@ def test_encoder_invalid_rel_pos_combination(): ) +def test_encoder_invalid_interctc_layer_idx(): + with pytest.raises(AssertionError): + ConformerEncoder( + 20, + num_blocks=2, + interctc_layer_idx=[0, 1], + ) + with pytest.raises(AssertionError): + ConformerEncoder( + 20, + num_blocks=6, + interctc_layer_idx=[1, 2], + ) + + def test_encoder_output_size(): encoder = ConformerEncoder(20, output_size=256) assert encoder.output_size() == 256 diff --git a/test/espnet2/asr/encoder/test_transformer_encoder.py b/test/espnet2/asr/encoder/test_transformer_encoder.py index 82bb317dc9c..5fe221eb66f 100644 --- a/test/espnet2/asr/encoder/test_transformer_encoder.py +++ b/test/espnet2/asr/encoder/test_transformer_encoder.py @@ -6,12 +6,22 @@ @pytest.mark.parametrize("input_layer", ["linear", "conv2d", "embed", None]) @pytest.mark.parametrize("positionwise_layer_type", ["conv1d", "conv1d-linear"]) -def test_Encoder_forward_backward(input_layer, positionwise_layer_type): +@pytest.mark.parametrize( + "interctc_layer_idx, interctc_use_conditioning", [([1], False), ([1], True)] +) +def test_Encoder_forward_backward( + input_layer, + positionwise_layer_type, + interctc_layer_idx, + interctc_use_conditioning, +): encoder = TransformerEncoder( 20, output_size=40, input_layer=input_layer, positionwise_layer_type=positionwise_layer_type, + interctc_layer_idx=interctc_layer_idx, + interctc_use_conditioning=interctc_use_conditioning, ) if input_layer == "embed": x = torch.randint(0, 10, [2, 10]) diff --git a/test/espnet2/asr/test_ctc.py b/test/espnet2/asr/test_ctc.py index a218e6d6815..6926ad31841 100644 --- a/test/espnet2/asr/test_ctc.py +++ b/test/espnet2/asr/test_ctc.py @@ -22,6 +22,14 @@ def test_ctc_forward_backward(ctc_type, ctc_args): ctc(*ctc_args).sum().backward() +@pytest.mark.parametrize("ctc_type", ["builtin", "warpctc", "gtnctc"]) +def test_ctc_softmax(ctc_type, ctc_args): + if ctc_type == "warpctc": + pytest.importorskip("warpctc_pytorch") + ctc = CTC(encoder_output_sizse=10, odim=5, ctc_type=ctc_type) + ctc.softmax(ctc_args[0]) + + @pytest.mark.parametrize("ctc_type", ["builtin", "warpctc", "gtnctc"]) def test_ctc_log_softmax(ctc_type, ctc_args): if 
ctc_type == "warpctc": From b39be4fea5208994f59e934747b70b18cd54f8db Mon Sep 17 00:00:00 2001 From: YosukeHiguchi Date: Mon, 28 Feb 2022 10:11:39 +0900 Subject: [PATCH 2/5] minor fix --- test/espnet2/asr/encoder/test_conformer_encoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/espnet2/asr/encoder/test_conformer_encoder.py b/test/espnet2/asr/encoder/test_conformer_encoder.py index 7849ef7f61f..321ced32d8f 100644 --- a/test/espnet2/asr/encoder/test_conformer_encoder.py +++ b/test/espnet2/asr/encoder/test_conformer_encoder.py @@ -44,7 +44,7 @@ def test_encoder_forward_backward( use_cnn_module=True, cnn_module_kernel=3, positionwise_layer_type=positionwise_layer_type, - interctc_layer_idx=interctc_layer_idx + interctc_layer_idx=interctc_layer_idx, interctc_use_conditioning=interctc_use_conditioning, ) if input_layer == "embed": From 1a51c373462542a9bce01822f2202ffce7d15214 Mon Sep 17 00:00:00 2001 From: YosukeHiguchi Date: Mon, 28 Feb 2022 10:59:01 +0900 Subject: [PATCH 3/5] fix typo sizse->size --- espnet2/asr/ctc.py | 6 +++--- espnet2/tasks/asr.py | 2 +- espnet2/tasks/enh_asr.py | 2 +- espnet2/tasks/st.py | 2 +- test/espnet2/asr/decoder/test_transformer_decoder.py | 2 +- test/espnet2/asr/test_ctc.py | 8 ++++---- 6 files changed, 11 insertions(+), 11 deletions(-) diff --git a/espnet2/asr/ctc.py b/espnet2/asr/ctc.py index 78fa431c458..64b87106ac8 100644 --- a/espnet2/asr/ctc.py +++ b/espnet2/asr/ctc.py @@ -10,7 +10,7 @@ class CTC(torch.nn.Module): Args: odim: dimension of outputs - encoder_output_sizse: number of encoder projection units + encoder_output_size: number of encoder projection units dropout_rate: dropout rate (0.0 ~ 1.0) ctc_type: builtin or warpctc reduce: reduce the CTC loss into a scalar @@ -19,7 +19,7 @@ class CTC(torch.nn.Module): def __init__( self, odim: int, - encoder_output_sizse: int, + encoder_output_size: int, dropout_rate: float = 0.0, ctc_type: str = "builtin", reduce: bool = True, @@ -27,7 +27,7 @@ def __init__( ): assert check_argument_types() super().__init__() - eprojs = encoder_output_sizse + eprojs = encoder_output_size self.dropout_rate = dropout_rate self.ctc_lo = torch.nn.Linear(eprojs, odim) self.ctc_type = ctc_type diff --git a/espnet2/tasks/asr.py b/espnet2/tasks/asr.py index ef198ccd5ae..780aa905697 100644 --- a/espnet2/tasks/asr.py +++ b/espnet2/tasks/asr.py @@ -477,7 +477,7 @@ def build_model(cls, args: argparse.Namespace) -> ESPnetASRModel: # 6. CTC ctc = CTC( - odim=vocab_size, encoder_output_sizse=encoder_output_size, **args.ctc_conf + odim=vocab_size, encoder_output_size=encoder_output_size, **args.ctc_conf ) # 8. Build model diff --git a/espnet2/tasks/enh_asr.py b/espnet2/tasks/enh_asr.py index 49d83e26ee9..c452ab2201d 100644 --- a/espnet2/tasks/enh_asr.py +++ b/espnet2/tasks/enh_asr.py @@ -339,7 +339,7 @@ def build_model(cls, args: argparse.Namespace) -> ESPnetEnhASRModel: # 6. CTC ctc = CTC( - odim=vocab_size, encoder_output_sizse=encoder.output_size(), **args.ctc_conf + odim=vocab_size, encoder_output_size=encoder.output_size(), **args.ctc_conf ) # 7. 
RNN-T Decoder (Not implemented) diff --git a/espnet2/tasks/st.py b/espnet2/tasks/st.py index d7b5a48c0c4..182a335cc56 100644 --- a/espnet2/tasks/st.py +++ b/espnet2/tasks/st.py @@ -516,7 +516,7 @@ def build_model(cls, args: argparse.Namespace) -> ESPnetSTModel: if src_token_list is not None: ctc = CTC( odim=src_vocab_size, - encoder_output_sizse=encoder_output_size, + encoder_output_size=encoder_output_size, **args.ctc_conf, ) else: diff --git a/test/espnet2/asr/decoder/test_transformer_decoder.py b/test/espnet2/asr/decoder/test_transformer_decoder.py index df44bcc7e43..d01c5b07a64 100644 --- a/test/espnet2/asr/decoder/test_transformer_decoder.py +++ b/test/espnet2/asr/decoder/test_transformer_decoder.py @@ -217,7 +217,7 @@ def test_TransformerDecoder_batch_beam_search_online( use_output_layer=use_output_layer, linear_units=10, ) - ctc = CTC(odim=vocab_size, encoder_output_sizse=encoder_output_size) + ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size) ctc.to(dtype) ctc_scorer = CTCPrefixScorer(ctc=ctc, eos=vocab_size - 1) beam = BatchBeamSearchOnlineSim( diff --git a/test/espnet2/asr/test_ctc.py b/test/espnet2/asr/test_ctc.py index 6926ad31841..5e17121415d 100644 --- a/test/espnet2/asr/test_ctc.py +++ b/test/espnet2/asr/test_ctc.py @@ -18,7 +18,7 @@ def ctc_args(): def test_ctc_forward_backward(ctc_type, ctc_args): if ctc_type == "warpctc": pytest.importorskip("warpctc_pytorch") - ctc = CTC(encoder_output_sizse=10, odim=5, ctc_type=ctc_type) + ctc = CTC(encoder_output_size=10, odim=5, ctc_type=ctc_type) ctc(*ctc_args).sum().backward() @@ -26,7 +26,7 @@ def test_ctc_forward_backward(ctc_type, ctc_args): def test_ctc_softmax(ctc_type, ctc_args): if ctc_type == "warpctc": pytest.importorskip("warpctc_pytorch") - ctc = CTC(encoder_output_sizse=10, odim=5, ctc_type=ctc_type) + ctc = CTC(encoder_output_size=10, odim=5, ctc_type=ctc_type) ctc.softmax(ctc_args[0]) @@ -34,7 +34,7 @@ def test_ctc_softmax(ctc_type, ctc_args): def test_ctc_log_softmax(ctc_type, ctc_args): if ctc_type == "warpctc": pytest.importorskip("warpctc_pytorch") - ctc = CTC(encoder_output_sizse=10, odim=5, ctc_type=ctc_type) + ctc = CTC(encoder_output_size=10, odim=5, ctc_type=ctc_type) ctc.log_softmax(ctc_args[0]) @@ -42,5 +42,5 @@ def test_ctc_log_softmax(ctc_type, ctc_args): def test_ctc_argmax(ctc_type, ctc_args): if ctc_type == "warpctc": pytest.importorskip("warpctc_pytorch") - ctc = CTC(encoder_output_sizse=10, odim=5, ctc_type=ctc_type) + ctc = CTC(encoder_output_size=10, odim=5, ctc_type=ctc_type) ctc.argmax(ctc_args[0]) From 3251ed4be12ff5efab939f2dec6577cd76627191 Mon Sep 17 00:00:00 2001 From: YosukeHiguchi Date: Mon, 28 Feb 2022 11:13:41 +0900 Subject: [PATCH 4/5] update tests --- .../asr/encoder/test_conformer_encoder.py | 24 ++++++++++-- .../asr/encoder/test_transformer_encoder.py | 37 ++++++++++++++++++- 2 files changed, 56 insertions(+), 5 deletions(-) diff --git a/test/espnet2/asr/encoder/test_conformer_encoder.py b/test/espnet2/asr/encoder/test_conformer_encoder.py index 321ced32d8f..2a38323d635 100644 --- a/test/espnet2/asr/encoder/test_conformer_encoder.py +++ b/test/espnet2/asr/encoder/test_conformer_encoder.py @@ -1,6 +1,7 @@ import pytest import torch +from espnet2.asr.ctc import CTC from espnet2.asr.encoder.conformer_encoder import ConformerEncoder @@ -18,7 +19,12 @@ ], ) @pytest.mark.parametrize( - "interctc_layer_idx, interctc_use_conditioning", [([1], False), ([1], True)] + "interctc_layer_idx, interctc_use_conditioning", + [ + ([], False), + ([1], False), + ([1], True), + ], ) def 
test_encoder_forward_backward( input_layer, @@ -52,7 +58,19 @@ def test_encoder_forward_backward( else: x = torch.randn(2, 32, 20, requires_grad=True) x_lens = torch.LongTensor([32, 28]) - y, _, _ = encoder(x, x_lens) + if len(interctc_layer_idx) > 0: + ctc=None + if interctc_use_conditioning: + vocab_size = 5 + output_size = encoder.output_size() + ctc = CTC(odim=vocab_size, encoder_output_size=output_size) + encoder.conditioning_layer = torch.nn.Linear( + vocab_size, output_size + ) + y, _, _ = encoder(x, x_lens, ctc=ctc) + y = y[0] + else: + y, _, _ = encoder(x, x_lens) y.sum().backward() @@ -99,7 +117,7 @@ def test_encoder_invalid_interctc_layer_idx(): with pytest.raises(AssertionError): ConformerEncoder( 20, - num_blocks=6, + num_blocks=2, interctc_layer_idx=[1, 2], ) diff --git a/test/espnet2/asr/encoder/test_transformer_encoder.py b/test/espnet2/asr/encoder/test_transformer_encoder.py index 5fe221eb66f..1caf48d69af 100644 --- a/test/espnet2/asr/encoder/test_transformer_encoder.py +++ b/test/espnet2/asr/encoder/test_transformer_encoder.py @@ -1,13 +1,19 @@ import pytest import torch +from espnet2.asr.ctc import CTC from espnet2.asr.encoder.transformer_encoder import TransformerEncoder @pytest.mark.parametrize("input_layer", ["linear", "conv2d", "embed", None]) @pytest.mark.parametrize("positionwise_layer_type", ["conv1d", "conv1d-linear"]) @pytest.mark.parametrize( - "interctc_layer_idx, interctc_use_conditioning", [([1], False), ([1], True)] + "interctc_layer_idx, interctc_use_conditioning", + [ + ([], False), + ([1], False), + ([1], True), + ], ) def test_Encoder_forward_backward( input_layer, @@ -30,10 +36,37 @@ def test_Encoder_forward_backward( else: x = torch.randn(2, 10, 20, requires_grad=True) x_lens = torch.LongTensor([10, 8]) - y, _, _ = encoder(x, x_lens) + if len(interctc_layer_idx) > 0: + ctc=None + if interctc_use_conditioning: + vocab_size = 5 + output_size = encoder.output_size() + ctc = CTC(odim=vocab_size, encoder_output_size=output_size) + encoder.conditioning_layer = torch.nn.Linear( + vocab_size, output_size + ) + y, _, _ = encoder(x, x_lens, ctc=ctc) + y = y[0] + else: + y, _, _ = encoder(x, x_lens) y.sum().backward() +def test_encoder_invalid_interctc_layer_idx(): + with pytest.raises(AssertionError): + TransformerEncoder( + 20, + num_blocks=2, + interctc_layer_idx=[0, 1], + ) + with pytest.raises(AssertionError): + TransformerEncoder( + 20, + num_blocks=2, + interctc_layer_idx=[1, 2], + ) + + def test_Encoder_output_size(): encoder = TransformerEncoder(20, output_size=256) assert encoder.output_size() == 256 From 4604a2df662cd74ad739eac1364fcaacfba67153 Mon Sep 17 00:00:00 2001 From: YosukeHiguchi Date: Mon, 28 Feb 2022 11:30:53 +0900 Subject: [PATCH 5/5] apply black --- test/espnet2/asr/encoder/test_conformer_encoder.py | 6 ++---- test/espnet2/asr/encoder/test_transformer_encoder.py | 6 ++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/test/espnet2/asr/encoder/test_conformer_encoder.py b/test/espnet2/asr/encoder/test_conformer_encoder.py index 2a38323d635..ddc2d077f9d 100644 --- a/test/espnet2/asr/encoder/test_conformer_encoder.py +++ b/test/espnet2/asr/encoder/test_conformer_encoder.py @@ -59,14 +59,12 @@ def test_encoder_forward_backward( x = torch.randn(2, 32, 20, requires_grad=True) x_lens = torch.LongTensor([32, 28]) if len(interctc_layer_idx) > 0: - ctc=None + ctc = None if interctc_use_conditioning: vocab_size = 5 output_size = encoder.output_size() ctc = CTC(odim=vocab_size, encoder_output_size=output_size) - 
encoder.conditioning_layer = torch.nn.Linear( - vocab_size, output_size - ) + encoder.conditioning_layer = torch.nn.Linear(vocab_size, output_size) y, _, _ = encoder(x, x_lens, ctc=ctc) y = y[0] else: diff --git a/test/espnet2/asr/encoder/test_transformer_encoder.py b/test/espnet2/asr/encoder/test_transformer_encoder.py index 1caf48d69af..15bd6e0f331 100644 --- a/test/espnet2/asr/encoder/test_transformer_encoder.py +++ b/test/espnet2/asr/encoder/test_transformer_encoder.py @@ -37,14 +37,12 @@ def test_Encoder_forward_backward( x = torch.randn(2, 10, 20, requires_grad=True) x_lens = torch.LongTensor([10, 8]) if len(interctc_layer_idx) > 0: - ctc=None + ctc = None if interctc_use_conditioning: vocab_size = 5 output_size = encoder.output_size() ctc = CTC(odim=vocab_size, encoder_output_size=output_size) - encoder.conditioning_layer = torch.nn.Linear( - vocab_size, output_size - ) + encoder.conditioning_layer = torch.nn.Linear(vocab_size, output_size) y, _, _ = encoder(x, x_lens, ctc=ctc) y = y[0] else:
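
Note: the tests in this series show how the new intermediate-CTC options are wired up. The sketch below is pieced together from the test code above and is a rough usage example rather than the canonical API: it assumes TransformerEncoder accepts interctc_layer_idx / interctc_use_conditioning as the tests do, that the conditioning_layer is attached manually exactly as in the tests, and that when intermediate CTC is enabled the first return value of forward() is a tuple whose first element is the final encoder output (the tests take y = y[0]). The concrete sizes (vocab_size=5, output_size=40, the dummy input shapes) are copied from the test fixtures, not prescribed values.

    import torch

    from espnet2.asr.ctc import CTC
    from espnet2.asr.encoder.transformer_encoder import TransformerEncoder

    # Encoder with an intermediate CTC branch after block 1, conditioning the
    # remaining blocks on its output (self-conditioned CTC), mirroring the
    # parametrization used in test_Encoder_forward_backward.
    encoder = TransformerEncoder(
        20,                              # input feature dimension
        output_size=40,
        interctc_layer_idx=[1],
        interctc_use_conditioning=True,
    )

    vocab_size = 5
    ctc = CTC(odim=vocab_size, encoder_output_size=encoder.output_size())
    # The tests attach the conditioning layer by hand; it maps CTC posteriors
    # back to the encoder dimension.
    encoder.conditioning_layer = torch.nn.Linear(vocab_size, encoder.output_size())

    # Dummy batch taken from the test: 2 utterances, 10 frames, 20-dim features.
    x = torch.randn(2, 10, 20, requires_grad=True)
    x_lens = torch.LongTensor([10, 8])

    # With interctc_layer_idx non-empty, the CTC module is passed to forward()
    # and the first return value is a tuple; its first element is the final
    # encoder output (per the `y = y[0]` in the tests above).
    out, olens, _ = encoder(x, x_lens, ctc=ctc)
    y = out[0]
    y.sum().backward()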