diff --git a/espnet2/asr/ctc.py b/espnet2/asr/ctc.py
index 78fa431c458..64b87106ac8 100644
--- a/espnet2/asr/ctc.py
+++ b/espnet2/asr/ctc.py
@@ -10,7 +10,7 @@ class CTC(torch.nn.Module):
 
     Args:
         odim: dimension of outputs
-        encoder_output_sizse: number of encoder projection units
+        encoder_output_size: number of encoder projection units
         dropout_rate: dropout rate (0.0 ~ 1.0)
         ctc_type: builtin or warpctc
         reduce: reduce the CTC loss into a scalar
@@ -19,7 +19,7 @@ class CTC(torch.nn.Module):
     def __init__(
         self,
         odim: int,
-        encoder_output_sizse: int,
+        encoder_output_size: int,
         dropout_rate: float = 0.0,
         ctc_type: str = "builtin",
         reduce: bool = True,
@@ -27,7 +27,7 @@ def __init__(
     ):
         assert check_argument_types()
         super().__init__()
-        eprojs = encoder_output_sizse
+        eprojs = encoder_output_size
         self.dropout_rate = dropout_rate
         self.ctc_lo = torch.nn.Linear(eprojs, odim)
         self.ctc_type = ctc_type
diff --git a/espnet2/tasks/asr.py b/espnet2/tasks/asr.py
index ef198ccd5ae..780aa905697 100644
--- a/espnet2/tasks/asr.py
+++ b/espnet2/tasks/asr.py
@@ -477,7 +477,7 @@ def build_model(cls, args: argparse.Namespace) -> ESPnetASRModel:
 
         # 6. CTC
         ctc = CTC(
-            odim=vocab_size, encoder_output_sizse=encoder_output_size, **args.ctc_conf
+            odim=vocab_size, encoder_output_size=encoder_output_size, **args.ctc_conf
         )
 
         # 8. Build model
diff --git a/espnet2/tasks/enh_asr.py b/espnet2/tasks/enh_asr.py
index 49d83e26ee9..c452ab2201d 100644
--- a/espnet2/tasks/enh_asr.py
+++ b/espnet2/tasks/enh_asr.py
@@ -339,7 +339,7 @@ def build_model(cls, args: argparse.Namespace) -> ESPnetEnhASRModel:
 
         # 6. CTC
         ctc = CTC(
-            odim=vocab_size, encoder_output_sizse=encoder.output_size(), **args.ctc_conf
+            odim=vocab_size, encoder_output_size=encoder.output_size(), **args.ctc_conf
         )
 
         # 7. RNN-T Decoder (Not implemented)
diff --git a/espnet2/tasks/st.py b/espnet2/tasks/st.py
index d7b5a48c0c4..182a335cc56 100644
--- a/espnet2/tasks/st.py
+++ b/espnet2/tasks/st.py
@@ -516,7 +516,7 @@ def build_model(cls, args: argparse.Namespace) -> ESPnetSTModel:
         if src_token_list is not None:
             ctc = CTC(
                 odim=src_vocab_size,
-                encoder_output_sizse=encoder_output_size,
+                encoder_output_size=encoder_output_size,
                 **args.ctc_conf,
             )
         else:
diff --git a/test/espnet2/asr/decoder/test_transformer_decoder.py b/test/espnet2/asr/decoder/test_transformer_decoder.py
index df44bcc7e43..d01c5b07a64 100644
--- a/test/espnet2/asr/decoder/test_transformer_decoder.py
+++ b/test/espnet2/asr/decoder/test_transformer_decoder.py
@@ -217,7 +217,7 @@ def test_TransformerDecoder_batch_beam_search_online(
         use_output_layer=use_output_layer,
         linear_units=10,
     )
-    ctc = CTC(odim=vocab_size, encoder_output_sizse=encoder_output_size)
+    ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size)
     ctc.to(dtype)
     ctc_scorer = CTCPrefixScorer(ctc=ctc, eos=vocab_size - 1)
     beam = BatchBeamSearchOnlineSim(
diff --git a/test/espnet2/asr/encoder/test_conformer_encoder.py b/test/espnet2/asr/encoder/test_conformer_encoder.py
index 43837c2ac4e..ddc2d077f9d 100644
--- a/test/espnet2/asr/encoder/test_conformer_encoder.py
+++ b/test/espnet2/asr/encoder/test_conformer_encoder.py
@@ -1,6 +1,7 @@
 import pytest
 import torch
 
+from espnet2.asr.ctc import CTC
 from espnet2.asr.encoder.conformer_encoder import ConformerEncoder
 
 
@@ -17,12 +18,22 @@
         ("legacy", "legacy_rel_pos", "legacy_rel_selfattn"),
     ],
 )
+@pytest.mark.parametrize(
+    "interctc_layer_idx, interctc_use_conditioning",
+    [
+        ([], False),
+        ([1], False),
+        ([1], True),
+    ],
+)
 def test_encoder_forward_backward(
     input_layer,
     positionwise_layer_type,
     rel_pos_type,
     pos_enc_layer_type,
     selfattention_layer_type,
+    interctc_layer_idx,
+    interctc_use_conditioning,
 ):
     encoder = ConformerEncoder(
         20,
@@ -39,13 +50,25 @@ def test_encoder_forward_backward(
         use_cnn_module=True,
         cnn_module_kernel=3,
         positionwise_layer_type=positionwise_layer_type,
+        interctc_layer_idx=interctc_layer_idx,
+        interctc_use_conditioning=interctc_use_conditioning,
     )
     if input_layer == "embed":
         x = torch.randint(0, 10, [2, 32])
     else:
         x = torch.randn(2, 32, 20, requires_grad=True)
     x_lens = torch.LongTensor([32, 28])
-    y, _, _ = encoder(x, x_lens)
+    if len(interctc_layer_idx) > 0:
+        ctc = None
+        if interctc_use_conditioning:
+            vocab_size = 5
+            output_size = encoder.output_size()
+            ctc = CTC(odim=vocab_size, encoder_output_size=output_size)
+            encoder.conditioning_layer = torch.nn.Linear(vocab_size, output_size)
+        y, _, _ = encoder(x, x_lens, ctc=ctc)
+        y = y[0]
+    else:
+        y, _, _ = encoder(x, x_lens)
     y.sum().backward()
 
 
@@ -82,6 +105,21 @@ def test_encoder_invalid_rel_pos_combination():
         )
 
 
+def test_encoder_invalid_interctc_layer_idx():
+    with pytest.raises(AssertionError):
+        ConformerEncoder(
+            20,
+            num_blocks=2,
+            interctc_layer_idx=[0, 1],
+        )
+    with pytest.raises(AssertionError):
+        ConformerEncoder(
+            20,
+            num_blocks=2,
+            interctc_layer_idx=[1, 2],
+        )
+
+
 def test_encoder_output_size():
     encoder = ConformerEncoder(20, output_size=256)
     assert encoder.output_size() == 256
diff --git a/test/espnet2/asr/encoder/test_transformer_encoder.py b/test/espnet2/asr/encoder/test_transformer_encoder.py
index 743595d5c37..fadb4fd2ef2 100644
--- a/test/espnet2/asr/encoder/test_transformer_encoder.py
+++ b/test/espnet2/asr/encoder/test_transformer_encoder.py
@@ -1,27 +1,68 @@
 import pytest
 import torch
 
+from espnet2.asr.ctc import CTC
 from espnet2.asr.encoder.transformer_encoder import TransformerEncoder
 
 
 @pytest.mark.parametrize("input_layer", ["linear", "conv2d", "embed", None])
 @pytest.mark.parametrize("positionwise_layer_type", ["conv1d", "conv1d-linear"])
-def test_Encoder_forward_backward(input_layer, positionwise_layer_type):
+@pytest.mark.parametrize(
+    "interctc_layer_idx, interctc_use_conditioning",
+    [
+        ([], False),
+        ([1], False),
+        ([1], True),
+    ],
+)
+def test_Encoder_forward_backward(
+    input_layer,
+    positionwise_layer_type,
+    interctc_layer_idx,
+    interctc_use_conditioning,
+):
     encoder = TransformerEncoder(
         20,
         output_size=40,
         input_layer=input_layer,
         positionwise_layer_type=positionwise_layer_type,
+        interctc_layer_idx=interctc_layer_idx,
+        interctc_use_conditioning=interctc_use_conditioning,
     )
     if input_layer == "embed":
         x = torch.randint(0, 10, [2, 10])
     else:
         x = torch.randn(2, 10, 20, requires_grad=True)
     x_lens = torch.LongTensor([10, 8])
-    y, _, _ = encoder(x, x_lens)
+    if len(interctc_layer_idx) > 0:
+        ctc = None
+        if interctc_use_conditioning:
+            vocab_size = 5
+            output_size = encoder.output_size()
+            ctc = CTC(odim=vocab_size, encoder_output_size=output_size)
+            encoder.conditioning_layer = torch.nn.Linear(vocab_size, output_size)
+        y, _, _ = encoder(x, x_lens, ctc=ctc)
+        y = y[0]
+    else:
+        y, _, _ = encoder(x, x_lens)
     y.sum().backward()
 
 
+def test_encoder_invalid_interctc_layer_idx():
+    with pytest.raises(AssertionError):
+        TransformerEncoder(
+            20,
+            num_blocks=2,
+            interctc_layer_idx=[0, 1],
+        )
+    with pytest.raises(AssertionError):
+        TransformerEncoder(
+            20,
+            num_blocks=2,
+            interctc_layer_idx=[1, 2],
+        )
+
+
 def test_Encoder_output_size():
     encoder = TransformerEncoder(20, output_size=256)
     assert encoder.output_size() == 256
diff --git a/test/espnet2/asr/test_ctc.py b/test/espnet2/asr/test_ctc.py
index a218e6d6815..5e17121415d 100644
--- a/test/espnet2/asr/test_ctc.py
+++ b/test/espnet2/asr/test_ctc.py
@@ -18,15 +18,23 @@ def ctc_args():
 def test_ctc_forward_backward(ctc_type, ctc_args):
     if ctc_type == "warpctc":
         pytest.importorskip("warpctc_pytorch")
-    ctc = CTC(encoder_output_sizse=10, odim=5, ctc_type=ctc_type)
+    ctc = CTC(encoder_output_size=10, odim=5, ctc_type=ctc_type)
     ctc(*ctc_args).sum().backward()
 
 
+@pytest.mark.parametrize("ctc_type", ["builtin", "warpctc", "gtnctc"])
+def test_ctc_softmax(ctc_type, ctc_args):
+    if ctc_type == "warpctc":
+        pytest.importorskip("warpctc_pytorch")
+    ctc = CTC(encoder_output_size=10, odim=5, ctc_type=ctc_type)
+    ctc.softmax(ctc_args[0])
+
+
 @pytest.mark.parametrize("ctc_type", ["builtin", "warpctc", "gtnctc"])
 def test_ctc_log_softmax(ctc_type, ctc_args):
     if ctc_type == "warpctc":
         pytest.importorskip("warpctc_pytorch")
-    ctc = CTC(encoder_output_sizse=10, odim=5, ctc_type=ctc_type)
+    ctc = CTC(encoder_output_size=10, odim=5, ctc_type=ctc_type)
     ctc.log_softmax(ctc_args[0])
 
 
@@ -34,5 +42,5 @@ def test_ctc_log_softmax(ctc_type, ctc_args):
 def test_ctc_argmax(ctc_type, ctc_args):
     if ctc_type == "warpctc":
         pytest.importorskip("warpctc_pytorch")
-    ctc = CTC(encoder_output_sizse=10, odim=5, ctc_type=ctc_type)
+    ctc = CTC(encoder_output_size=10, odim=5, ctc_type=ctc_type)
     ctc.argmax(ctc_args[0])
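For context, a minimal usage sketch of what this patch touches: building a CTC module with the corrected `encoder_output_size` keyword and running an encoder with intermediate CTC plus self-conditioning, as the new tests do. The sketch is illustrative only; the toy sizes (input dim 20, `vocab_size = 5`, `interctc_layer_idx=[1]`) are taken from the test values above, and attaching `conditioning_layer` by hand mirrors the tests rather than any particular training setup.

```python
# Illustrative sketch only (toy sizes mirror the tests above); not part of the patch.
import torch

from espnet2.asr.ctc import CTC
from espnet2.asr.encoder.transformer_encoder import TransformerEncoder

vocab_size = 5  # toy vocabulary size, as in the tests
encoder = TransformerEncoder(
    20,  # input feature dimension (toy value)
    output_size=40,
    interctc_layer_idx=[1],          # also emit a CTC output after block 1
    interctc_use_conditioning=True,  # feed CTC posteriors back into the encoder
)
output_size = encoder.output_size()

# The corrected keyword: encoder_output_size (was "encoder_output_sizse").
ctc = CTC(odim=vocab_size, encoder_output_size=output_size)

# The tests attach the conditioning layer manually; it maps CTC posteriors
# back to the encoder dimension for self-conditioned CTC.
encoder.conditioning_layer = torch.nn.Linear(vocab_size, output_size)

x = torch.randn(2, 10, 20)
x_lens = torch.LongTensor([10, 8])
# With interctc_layer_idx set, the encoder's first return value carries both
# the final output and the intermediate outputs; the tests take element 0.
y, olens, _ = encoder(x, x_lens, ctc=ctc)
encoder_out = y[0]
```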