From 406656cdcb668a77910074b4382b557b6f845c54 Mon Sep 17 00:00:00 2001 From: earthmanylf <411214987@qq.com> Date: Thu, 28 Apr 2022 11:10:11 +0800 Subject: [PATCH] Add custom name in __init__ in tf_domain.py; Merge test_dpcl_loss.py to test_tf_domain.py --- espnet2/enh/loss/criterions/tf_domain.py | 7 ++-- .../enh/loss/criterions/test_dpcl_loss.py | 32 ------------------- .../enh/loss/criterions/test_tf_domain.py | 20 ++++++++++++ 3 files changed, 25 insertions(+), 34 deletions(-) delete mode 100644 test/espnet2/enh/loss/criterions/test_dpcl_loss.py diff --git a/espnet2/enh/loss/criterions/tf_domain.py b/espnet2/enh/loss/criterions/tf_domain.py index c4a383088cb..c94678e4244 100644 --- a/espnet2/enh/loss/criterions/tf_domain.py +++ b/espnet2/enh/loss/criterions/tf_domain.py @@ -226,11 +226,14 @@ def forward(self, ref, inf) -> torch.Tensor: class FrequencyDomainDPCL(FrequencyDomainLoss): - def __init__(self, compute_on_mask=False, mask_type="IBM", loss_type="dpcl"): + def __init__( + self, compute_on_mask=False, mask_type="IBM", loss_type="dpcl", name=None + ): super().__init__() self._compute_on_mask = compute_on_mask self._mask_type = mask_type self._loss_type = loss_type + self._name = "dpcl" if name is None else name @property def compute_on_mask(self) -> bool: @@ -242,7 +245,7 @@ def mask_type(self) -> str: @property def name(self) -> str: - return "dpcl" + return self._name def forward(self, ref, inf) -> torch.Tensor: """time-frequency Deep Clustering loss. diff --git a/test/espnet2/enh/loss/criterions/test_dpcl_loss.py b/test/espnet2/enh/loss/criterions/test_dpcl_loss.py deleted file mode 100644 index 50f04776013..00000000000 --- a/test/espnet2/enh/loss/criterions/test_dpcl_loss.py +++ /dev/null @@ -1,32 +0,0 @@ -import pytest - -import torch -from torch_complex import ComplexTensor - -from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainDPCL - - -@pytest.mark.parametrize("criterion_class", [FrequencyDomainDPCL]) -@pytest.mark.parametrize("mask_type", ["IBM"]) -@pytest.mark.parametrize("compute_on_mask", [False]) -@pytest.mark.parametrize("loss_type", ["dpcl", "mdc"]) -def test_tf_domain_criterion_forward( - criterion_class, mask_type, compute_on_mask, loss_type -): - - criterion = criterion_class( - compute_on_mask=compute_on_mask, mask_type=mask_type, loss_type=loss_type - ) - - batch = 2 - inf = torch.rand(batch, 10 * 200, 40) - ref_spec = [ - ComplexTensor(torch.rand(batch, 10, 200), torch.rand(batch, 10, 200)), - ComplexTensor(torch.rand(batch, 10, 200), torch.rand(batch, 10, 200)), - ComplexTensor(torch.rand(batch, 10, 200), torch.rand(batch, 10, 200)), - ] - - ref = [abs(r) for r in ref_spec] - - loss = criterion(ref, inf) - assert loss.shape == (batch,) diff --git a/test/espnet2/enh/loss/criterions/test_tf_domain.py b/test/espnet2/enh/loss/criterions/test_tf_domain.py index 41999948887..9d1cec94a1d 100644 --- a/test/espnet2/enh/loss/criterions/test_tf_domain.py +++ b/test/espnet2/enh/loss/criterions/test_tf_domain.py @@ -6,6 +6,7 @@ from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainAbsCoherence from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainCrossEntropy +from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainDPCL from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainL1 from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainMSE @@ -93,3 +94,22 @@ def test_tf_ce_criterion_forward(input_ch): loss = criterion(ref_spec, inf_spec) assert loss.shape == (batch,), "Invlid loss shape with " + criterion.name + + +@pytest.mark.parametrize("loss_type", ["dpcl", "mdc"]) +def test_tf_dpcl_loss_criterion_forward(loss_type): + + criterion = FrequencyDomainDPCL(loss_type=loss_type) + + batch = 2 + inf = torch.rand(batch, 10 * 200, 40) + ref_spec = [ + ComplexTensor(torch.rand(batch, 10, 200), torch.rand(batch, 10, 200)), + ComplexTensor(torch.rand(batch, 10, 200), torch.rand(batch, 10, 200)), + ComplexTensor(torch.rand(batch, 10, 200), torch.rand(batch, 10, 200)), + ] + + ref = [abs(r) for r in ref_spec] + + loss = criterion(ref, inf) + assert loss.shape == (batch,), "Invlid loss shape with " + criterion.name