update unittests per PR comments
Caroline Chen committed Aug 18, 2021
1 parent 7da82d5 commit dae3de5
Showing 4 changed files with 24 additions and 43 deletions.
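
In short: the RNN-T test helpers are no longer re-exported through `common_utils/__init__.py`; the affected test modules now import the `rnnt_utils` submodule and refer to its helpers by qualified name. A minimal sketch of the resulting style, assuming the torchaudio test suite is on the import path (the plain `"cpu"` device argument is an assumption; the actual tests pass `self.device`):

```python
# Namespaced import style the tests switch to in this commit.
from torchaudio_unittest.common_utils import rnnt_utils

# Helpers are referenced through the module namespace rather than being
# re-exported by common_utils/__init__.py.
logits, targets, logit_lengths, target_lengths = rnnt_utils.get_basic_data("cpu")
```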
test/torchaudio_unittest/common_utils/__init__.py (16 changes: 0 additions & 16 deletions)
@@ -31,15 +31,6 @@
     load_params,
     nested_params
 )
-from .rnnt_utils import (
-    compute_with_numpy_transducer,
-    compute_with_pytorch_transducer,
-    get_basic_data,
-    get_B1_T10_U3_D4_data,
-    get_B2_T4_U3_D3_data,
-    get_B1_T2_U3_D5_data,
-    get_random_data,
-)
 
 __all__ = [
     'get_asset_path',
@@ -66,11 +57,4 @@
     'save_wav',
     'load_params',
     'nested_params',
-    'compute_with_numpy_transducer',
-    'compute_with_pytorch_transducer',
-    'get_basic_data',
-    'get_B1_T10_U3_D4_data',
-    'get_B2_T4_U3_D3_data',
-    'get_B1_T2_U3_D5_data',
-    'get_random_data',
 ]
test/torchaudio_unittest/functional/autograd_impl.py (10 changes: 4 additions & 6 deletions)
@@ -8,9 +8,7 @@
 from torchaudio_unittest.common_utils import (
     TestBaseMixin,
     get_whitenoise,
-    get_B1_T10_U3_D4_data,
-    get_B2_T4_U3_D3_data,
-    get_B1_T2_U3_D5_data,
+    rnnt_utils,
 )
 
 
@@ -215,9 +213,9 @@ def assert_grad(
         assert gradcheck(transform, inputs, eps=1e-3, atol=1e-3, nondet_tol=0.)
 
     @parameterized.expand([
-        (get_B1_T10_U3_D4_data, ),
-        (get_B2_T4_U3_D3_data, ),
-        (get_B1_T2_U3_D5_data, ),
+        (rnnt_utils.get_B1_T10_U3_D4_data, ),
+        (rnnt_utils.get_B2_T4_U3_D3_data, ),
+        (rnnt_utils.get_B1_T2_U3_D5_data, ),
     ])
     def test_rnnt_loss(self, data_func):
         def get_data(data_func, device):
test/torchaudio_unittest/functional/functional_impl.py (31 changes: 16 additions & 15 deletions)
@@ -14,12 +14,7 @@
     get_sinusoid,
     nested_params,
     get_whitenoise,
-    compute_with_numpy_transducer,
-    compute_with_pytorch_transducer,
-    get_basic_data,
-    get_B1_T2_U3_D5_data,
-    get_B2_T4_U3_D3_data,
-    get_random_data,
+    rnnt_utils,
 )
 
 
@@ -57,7 +52,7 @@ def _test_costs_and_gradients(
         self, data, ref_costs, ref_gradients, atol=1e-6, rtol=1e-2
     ):
         logits_shape = data["logits"].shape
-        costs, gradients = compute_with_pytorch_transducer(data=data)
+        costs, gradients = rnnt_utils.compute_with_pytorch_transducer(data=data)
         self.assertEqual(costs, ref_costs, atol=atol, rtol=rtol)
         self.assertEqual(logits_shape, gradients.shape)
         self.assertEqual(gradients, ref_gradients, atol=atol, rtol=rtol)
@@ -457,20 +452,26 @@ def test_pitch_shift_shape(self, n_steps):
         assert waveform.size() == waveform_shift.size()
 
     def test_rnnt_loss_basic_backward(self):
-        logits, targets, logit_lengths, target_lengths = get_basic_data(self.device)
+        logits, targets, logit_lengths, target_lengths = rnnt_utils.get_basic_data(self.device)
         loss = F.rnnt_loss(logits, targets, logit_lengths, target_lengths)
         loss.backward()
 
     def test_rnnt_loss_basic_forward_no_grad(self):
-        logits, targets, logit_lengths, target_lengths = get_basic_data(self.device)
+        """In early stage, calls to `rnnt_loss` resulted in segmentation fault when
+        `logits` have `requires_grad = False`. This test makes sure that this no longer
+        occurs and the functional call runs without error.
+        See https://github.com/pytorch/audio/pull/1707
+        """
+        logits, targets, logit_lengths, target_lengths = rnnt_utils.get_basic_data(self.device)
         logits.requires_grad_(False)
         F.rnnt_loss(logits, targets, logit_lengths, target_lengths)
 
     @parameterized.expand([
-        (get_B1_T2_U3_D5_data, torch.float32, 1e-6, 1e-2),
-        (get_B2_T4_U3_D3_data, torch.float32, 1e-6, 1e-2),
-        (get_B1_T2_U3_D5_data, torch.float16, 1e-3, 1e-2),
-        (get_B2_T4_U3_D3_data, torch.float16, 1e-3, 1e-2),
+        (rnnt_utils.get_B1_T2_U3_D5_data, torch.float32, 1e-6, 1e-2),
+        (rnnt_utils.get_B2_T4_U3_D3_data, torch.float32, 1e-6, 1e-2),
+        (rnnt_utils.get_B1_T2_U3_D5_data, torch.float16, 1e-3, 1e-2),
+        (rnnt_utils.get_B2_T4_U3_D3_data, torch.float16, 1e-3, 1e-2),
     ])
     def test_rnnt_loss_costs_and_gradients(self, data_func, dtype, atol, rtol):
         data, ref_costs, ref_gradients = data_func(
@@ -488,8 +489,8 @@ def test_rnnt_loss_costs_and_gradients(self, data_func, dtype, atol, rtol):
     def test_rnnt_loss_costs_and_gradients_random_data_with_numpy_fp32(self):
         seed = 777
         for i in range(5):
-            data = get_random_data(dtype=torch.float32, device=self.device, seed=(seed + i))
-            ref_costs, ref_gradients = compute_with_numpy_transducer(data=data)
+            data = rnnt_utils.get_random_data(dtype=torch.float32, device=self.device, seed=(seed + i))
+            ref_costs, ref_gradients = rnnt_utils.compute_with_numpy_transducer(data=data)
             self._test_costs_and_gradients(
                 data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
             )
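
The docstring added to `test_rnnt_loss_basic_forward_no_grad` above records that early versions of `rnnt_loss` could segfault when `logits` did not require grad. A self-contained sketch of that forward-only call pattern, with shapes and values made up purely for illustration (the real test builds its inputs with `rnnt_utils.get_basic_data`):

```python
import torch
import torchaudio.functional as F

# Toy RNN-T inputs: logits of shape (batch, max source length, max target length + 1, classes),
# integer targets, and per-sequence lengths. All values here are illustrative only.
B, T, U, D = 2, 4, 2, 5
logits = torch.rand(B, T, U + 1, D, dtype=torch.float32)
targets = torch.randint(low=0, high=D - 1, size=(B, U), dtype=torch.int32)
logit_lengths = torch.full((B,), T, dtype=torch.int32)
target_lengths = torch.full((B,), U, dtype=torch.int32)

# Forward-only call: logits deliberately do not require grad, which is the
# situation the regression test guards against.
logits.requires_grad_(False)
loss = F.rnnt_loss(logits, targets, logit_lengths, target_lengths)
print(loss.item())
```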
test/torchaudio_unittest/transforms/autograd_test_impl.py (10 changes: 4 additions & 6 deletions)
@@ -11,9 +11,7 @@
     get_whitenoise,
     get_spectrogram,
     nested_params,
-    get_B1_T10_U3_D4_data,
-    get_B2_T4_U3_D3_data,
-    get_B1_T2_U3_D5_data,
+    rnnt_utils,
 )
 
 
@@ -280,9 +278,9 @@ def assert_grad(
         assert gradcheck(transform, inputs, eps=1e-3, atol=1e-3, nondet_tol=0.)
 
     @parameterized.expand([
-        (get_B1_T10_U3_D4_data, ),
-        (get_B2_T4_U3_D3_data, ),
-        (get_B1_T2_U3_D5_data, ),
+        (rnnt_utils.get_B1_T10_U3_D4_data, ),
+        (rnnt_utils.get_B2_T4_U3_D3_data, ),
+        (rnnt_utils.get_B1_T2_U3_D5_data, ),
     ])
     def test_rnnt_loss(self, data_func):
         def get_data(data_func, device):
