Merge pull request espnet#4117 from YosukeHiguchi/espnet2_interctc
Add tests for Intermediate/Self-conditioned CTC
sw005320 authored Mar 1, 2022
2 parents 5edb478 + 4604a2d commit 7999009
Showing 8 changed files with 100 additions and 13 deletions.
6 changes: 3 additions & 3 deletions espnet2/asr/ctc.py
@@ -10,7 +10,7 @@ class CTC(torch.nn.Module):
     Args:
         odim: dimension of outputs
-        encoder_output_sizse: number of encoder projection units
+        encoder_output_size: number of encoder projection units
         dropout_rate: dropout rate (0.0 ~ 1.0)
         ctc_type: builtin or warpctc
         reduce: reduce the CTC loss into a scalar
Expand All @@ -19,15 +19,15 @@ class CTC(torch.nn.Module):
def __init__(
self,
odim: int,
encoder_output_sizse: int,
encoder_output_size: int,
dropout_rate: float = 0.0,
ctc_type: str = "builtin",
reduce: bool = True,
ignore_nan_grad: bool = True,
):
assert check_argument_types()
super().__init__()
eprojs = encoder_output_sizse
eprojs = encoder_output_size
self.dropout_rate = dropout_rate
self.ctc_lo = torch.nn.Linear(eprojs, odim)
self.ctc_type = ctc_type
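The change above renames the misspelled constructor keyword encoder_output_sizse to encoder_output_size, so every caller that builds CTC by keyword has to follow suit; the task builders below are updated accordingly. A minimal sketch of the corrected call — the vocabulary and encoder sizes here are placeholders, not values taken from this PR:

import torch

from espnet2.asr.ctc import CTC

vocab_size = 5000            # placeholder output vocabulary size
encoder_output_size = 256    # placeholder encoder projection width

ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size)

# The module projects encoder frames (batch, time, encoder_output_size)
# to per-frame token log-posteriors (batch, time, odim).
hs = torch.randn(2, 50, encoder_output_size)
print(ctc.log_softmax(hs).shape)  # expected: torch.Size([2, 50, 5000])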
2 changes: 1 addition & 1 deletion espnet2/tasks/asr.py
@@ -477,7 +477,7 @@ def build_model(cls, args: argparse.Namespace) -> ESPnetASRModel:
 
         # 6. CTC
         ctc = CTC(
-            odim=vocab_size, encoder_output_sizse=encoder_output_size, **args.ctc_conf
+            odim=vocab_size, encoder_output_size=encoder_output_size, **args.ctc_conf
         )
 
         # 8. Build model
2 changes: 1 addition & 1 deletion espnet2/tasks/enh_asr.py
@@ -339,7 +339,7 @@ def build_model(cls, args: argparse.Namespace) -> ESPnetEnhASRModel:
 
         # 6. CTC
         ctc = CTC(
-            odim=vocab_size, encoder_output_sizse=encoder.output_size(), **args.ctc_conf
+            odim=vocab_size, encoder_output_size=encoder.output_size(), **args.ctc_conf
         )
 
         # 7. RNN-T Decoder (Not implemented)
2 changes: 1 addition & 1 deletion espnet2/tasks/st.py
@@ -516,7 +516,7 @@ def build_model(cls, args: argparse.Namespace) -> ESPnetSTModel:
         if src_token_list is not None:
             ctc = CTC(
                 odim=src_vocab_size,
-                encoder_output_sizse=encoder_output_size,
+                encoder_output_size=encoder_output_size,
                 **args.ctc_conf,
             )
         else:
2 changes: 1 addition & 1 deletion test/espnet2/asr/decoder/test_transformer_decoder.py
@@ -217,7 +217,7 @@ def test_TransformerDecoder_batch_beam_search_online(
         use_output_layer=use_output_layer,
         linear_units=10,
     )
-    ctc = CTC(odim=vocab_size, encoder_output_sizse=encoder_output_size)
+    ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size)
     ctc.to(dtype)
     ctc_scorer = CTCPrefixScorer(ctc=ctc, eos=vocab_size - 1)
     beam = BatchBeamSearchOnlineSim(
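In the decoder test, the corrected CTC module is wrapped in CTCPrefixScorer so it can take part in joint CTC/attention beam search. A rough, self-contained sketch of that wrapping with placeholder sizes; the scorer/weight dictionaries in the comment are illustrative of the espnet beam-search interface, not part of this diff:

from espnet.nets.scorers.ctc import CTCPrefixScorer
from espnet2.asr.ctc import CTC

vocab_size = 10              # placeholder
encoder_output_size = 20     # placeholder

ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size)
ctc_scorer = CTCPrefixScorer(ctc=ctc, eos=vocab_size - 1)

# The scorer is then registered alongside an attention decoder, e.g.
#   scorers = {"decoder": decoder, "ctc": ctc_scorer}
#   weights = {"decoder": 0.7, "ctc": 0.3}
# which is the setup BatchBeamSearchOnlineSim exercises in this test.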
40 changes: 39 additions & 1 deletion test/espnet2/asr/encoder/test_conformer_encoder.py
@@ -1,6 +1,7 @@
 import pytest
 import torch
 
+from espnet2.asr.ctc import CTC
 from espnet2.asr.encoder.conformer_encoder import ConformerEncoder
 
 
Expand All @@ -17,12 +18,22 @@
("legacy", "legacy_rel_pos", "legacy_rel_selfattn"),
],
)
@pytest.mark.parametrize(
"interctc_layer_idx, interctc_use_conditioning",
[
([], False),
([1], False),
([1], True),
],
)
def test_encoder_forward_backward(
input_layer,
positionwise_layer_type,
rel_pos_type,
pos_enc_layer_type,
selfattention_layer_type,
interctc_layer_idx,
interctc_use_conditioning,
):
encoder = ConformerEncoder(
20,
Expand All @@ -39,13 +50,25 @@ def test_encoder_forward_backward(
use_cnn_module=True,
cnn_module_kernel=3,
positionwise_layer_type=positionwise_layer_type,
interctc_layer_idx=interctc_layer_idx,
interctc_use_conditioning=interctc_use_conditioning,
)
if input_layer == "embed":
x = torch.randint(0, 10, [2, 32])
else:
x = torch.randn(2, 32, 20, requires_grad=True)
x_lens = torch.LongTensor([32, 28])
y, _, _ = encoder(x, x_lens)
if len(interctc_layer_idx) > 0:
ctc = None
if interctc_use_conditioning:
vocab_size = 5
output_size = encoder.output_size()
ctc = CTC(odim=vocab_size, encoder_output_size=output_size)
encoder.conditioning_layer = torch.nn.Linear(vocab_size, output_size)
y, _, _ = encoder(x, x_lens, ctc=ctc)
y = y[0]
else:
y, _, _ = encoder(x, x_lens)
y.sum().backward()


@@ -82,6 +105,21 @@ def test_encoder_invalid_rel_pos_combination():
     )
 
 
+def test_encoder_invalid_interctc_layer_idx():
+    with pytest.raises(AssertionError):
+        ConformerEncoder(
+            20,
+            num_blocks=2,
+            interctc_layer_idx=[0, 1],
+        )
+    with pytest.raises(AssertionError):
+        ConformerEncoder(
+            20,
+            num_blocks=2,
+            interctc_layer_idx=[1, 2],
+        )
+
+
 def test_encoder_output_size():
     encoder = ConformerEncoder(20, output_size=256)
     assert encoder.output_size() == 256
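The new parametrization covers the intermediate-CTC path of ConformerEncoder: a non-empty interctc_layer_idx makes the encoder emit auxiliary outputs at the listed blocks, and interctc_use_conditioning additionally feeds the CTC posteriors of those blocks back into the encoder through conditioning_layer (self-conditioned CTC). A standalone sketch of the same wiring the test performs; sizes are illustrative, the defaults are assumed to be mutually consistent (the test pins the attention and positional-encoding options explicitly), and treating y[0] as the final encoder output follows the test's own unpacking:

import torch

from espnet2.asr.ctc import CTC
from espnet2.asr.encoder.conformer_encoder import ConformerEncoder

vocab_size = 5   # placeholder vocabulary size
feat_dim = 20    # placeholder input feature dimension

encoder = ConformerEncoder(
    feat_dim,
    output_size=40,
    num_blocks=2,
    interctc_layer_idx=[1],          # auxiliary CTC branch after block 1
    interctc_use_conditioning=True,  # feed its posteriors back into block 2
)
output_size = encoder.output_size()
ctc = CTC(odim=vocab_size, encoder_output_size=output_size)
encoder.conditioning_layer = torch.nn.Linear(vocab_size, output_size)

x = torch.randn(2, 32, feat_dim, requires_grad=True)
x_lens = torch.LongTensor([32, 28])

y, olens, _ = encoder(x, x_lens, ctc=ctc)
final_out = y[0]  # first element is the final encoder output; the rest carries the intermediate outputs
final_out.sum().backward()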
45 changes: 43 additions & 2 deletions test/espnet2/asr/encoder/test_transformer_encoder.py
@@ -1,27 +1,68 @@
 import pytest
 import torch
 
+from espnet2.asr.ctc import CTC
 from espnet2.asr.encoder.transformer_encoder import TransformerEncoder
 
 
 @pytest.mark.parametrize("input_layer", ["linear", "conv2d", "embed", None])
 @pytest.mark.parametrize("positionwise_layer_type", ["conv1d", "conv1d-linear"])
-def test_Encoder_forward_backward(input_layer, positionwise_layer_type):
+@pytest.mark.parametrize(
+    "interctc_layer_idx, interctc_use_conditioning",
+    [
+        ([], False),
+        ([1], False),
+        ([1], True),
+    ],
+)
+def test_Encoder_forward_backward(
+    input_layer,
+    positionwise_layer_type,
+    interctc_layer_idx,
+    interctc_use_conditioning,
+):
     encoder = TransformerEncoder(
         20,
         output_size=40,
         input_layer=input_layer,
         positionwise_layer_type=positionwise_layer_type,
+        interctc_layer_idx=interctc_layer_idx,
+        interctc_use_conditioning=interctc_use_conditioning,
     )
     if input_layer == "embed":
         x = torch.randint(0, 10, [2, 10])
     else:
         x = torch.randn(2, 10, 20, requires_grad=True)
     x_lens = torch.LongTensor([10, 8])
-    y, _, _ = encoder(x, x_lens)
+    if len(interctc_layer_idx) > 0:
+        ctc = None
+        if interctc_use_conditioning:
+            vocab_size = 5
+            output_size = encoder.output_size()
+            ctc = CTC(odim=vocab_size, encoder_output_size=output_size)
+            encoder.conditioning_layer = torch.nn.Linear(vocab_size, output_size)
+        y, _, _ = encoder(x, x_lens, ctc=ctc)
+        y = y[0]
+    else:
+        y, _, _ = encoder(x, x_lens)
     y.sum().backward()
 
 
+def test_encoder_invalid_interctc_layer_idx():
+    with pytest.raises(AssertionError):
+        TransformerEncoder(
+            20,
+            num_blocks=2,
+            interctc_layer_idx=[0, 1],
+        )
+    with pytest.raises(AssertionError):
+        TransformerEncoder(
+            20,
+            num_blocks=2,
+            interctc_layer_idx=[1, 2],
+        )
+
+
 def test_Encoder_output_size():
     encoder = TransformerEncoder(20, output_size=256)
     assert encoder.output_size() == 256
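Together with the conformer variant above, test_encoder_invalid_interctc_layer_idx pins down the valid range of interctc_layer_idx: with num_blocks=2, both [0, 1] and [1, 2] are rejected, so every index must fall strictly between 0 and num_blocks. The snippet below states that inferred constraint directly; it is a reading of what these tests require, not a quote of the encoders' actual assertion:

# Inferred validity rule for interctc_layer_idx (illustrative, 1-based block indices).
def interctc_layer_idx_is_valid(interctc_layer_idx, num_blocks):
    return all(0 < idx < num_blocks for idx in interctc_layer_idx)

assert interctc_layer_idx_is_valid([1], num_blocks=2)           # used by the forward/backward tests
assert not interctc_layer_idx_is_valid([0, 1], num_blocks=2)    # rejected: index 0 is out of range
assert not interctc_layer_idx_is_valid([1, 2], num_blocks=2)    # rejected: the final block gets no branch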
14 changes: 11 additions & 3 deletions test/espnet2/asr/test_ctc.py
@@ -18,21 +18,29 @@ def ctc_args():
 def test_ctc_forward_backward(ctc_type, ctc_args):
     if ctc_type == "warpctc":
         pytest.importorskip("warpctc_pytorch")
-    ctc = CTC(encoder_output_sizse=10, odim=5, ctc_type=ctc_type)
+    ctc = CTC(encoder_output_size=10, odim=5, ctc_type=ctc_type)
     ctc(*ctc_args).sum().backward()
 
 
+@pytest.mark.parametrize("ctc_type", ["builtin", "warpctc", "gtnctc"])
+def test_ctc_softmax(ctc_type, ctc_args):
+    if ctc_type == "warpctc":
+        pytest.importorskip("warpctc_pytorch")
+    ctc = CTC(encoder_output_size=10, odim=5, ctc_type=ctc_type)
+    ctc.softmax(ctc_args[0])
+
+
 @pytest.mark.parametrize("ctc_type", ["builtin", "warpctc", "gtnctc"])
 def test_ctc_log_softmax(ctc_type, ctc_args):
     if ctc_type == "warpctc":
         pytest.importorskip("warpctc_pytorch")
-    ctc = CTC(encoder_output_sizse=10, odim=5, ctc_type=ctc_type)
+    ctc = CTC(encoder_output_size=10, odim=5, ctc_type=ctc_type)
     ctc.log_softmax(ctc_args[0])
 
 
 @pytest.mark.parametrize("ctc_type", ["builtin", "warpctc", "gtnctc"])
 def test_ctc_argmax(ctc_type, ctc_args):
     if ctc_type == "warpctc":
         pytest.importorskip("warpctc_pytorch")
-    ctc = CTC(encoder_output_sizse=10, odim=5, ctc_type=ctc_type)
+    ctc = CTC(encoder_output_size=10, odim=5, ctc_type=ctc_type)
     ctc.argmax(ctc_args[0])
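The added test_ctc_softmax rounds out coverage of the three posterior helpers on the CTC head (softmax, log_softmax, argmax). A small sketch of how they relate, using the same toy sizes as the tests; the (batch, time, encoder_output_size) layout of the input is assumed from the encoder_output_size=10 setting in the fixture:

import torch

from espnet2.asr.ctc import CTC

ctc = CTC(encoder_output_size=10, odim=5)

hs = torch.randn(2, 12, 10)      # toy batch of encoder frames

probs = ctc.softmax(hs)          # per-frame posteriors over the 5 tokens
log_probs = ctc.log_softmax(hs)  # log of the same posteriors
tokens = ctc.argmax(hs)          # greedy per-frame token ids

assert probs.shape == (2, 12, 5)
assert log_probs.shape == (2, 12, 5)
assert tokens.shape == (2, 12)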
