Skip to content

Commit

Permalink
remove unnecessary code
Browse files Browse the repository at this point in the history
  • Loading branch information
Caroline Chen committed Jul 9, 2021
1 parent 1398197 commit 326e34d
Showing 1 changed file with 16 additions and 17 deletions.
33 changes: 16 additions & 17 deletions test/torchaudio_unittest/rnnt/rnnt_loss_impl.py
Original file line number Diff line number Diff line change
def _test_costs_and_gradients(
    self, data, ref_costs, ref_gradients, atol=1e-6, rtol=1e-2
):
    """Check transducer costs and gradients against reference values.

    Runs ``compute_with_pytorch_transducer`` on ``data``, asserts that the
    returned costs match ``ref_costs`` and that the gradients have the same
    shape as ``data["logits"]``. If the gradients do not match
    ``ref_gradients`` as a whole, re-checks them element by element so the
    assertion failure pinpoints the exact (batch, time, target) position.

    Args:
        data: dict with at least "logits", "logit_lengths", and
            "target_lengths" entries (project-specific fixture — see callers).
        ref_costs: expected per-sequence costs.
        ref_gradients: expected gradients w.r.t. the logits.
        atol: absolute tolerance for the comparisons.
        rtol: relative tolerance for the comparisons.
    """
    logits_shape = data["logits"].shape
    costs, gradients = compute_with_pytorch_transducer(data=data)
    np.testing.assert_allclose(costs, ref_costs, atol=atol, rtol=rtol)
    self.assertEqual(logits_shape, gradients.shape)
    if not np.allclose(gradients, ref_gradients, atol=atol, rtol=rtol):
        # The bulk comparison failed; redo it per element purely to produce
        # a diagnostic message locating the first divergent entry. T and U
        # are the valid logit/target lengths for batch item b, included in
        # the message so out-of-range (padding) positions are recognizable.
        for b in range(len(gradients)):
            T = data["logit_lengths"][b]
            U = data["target_lengths"][b]
            for t in range(gradients.shape[1]):
                for u in range(gradients.shape[2]):
                    np.testing.assert_allclose(
                        gradients[b, t, u],
                        ref_gradients[b, t, u],
                        atol=atol,
                        rtol=rtol,
                        err_msg=f"failed on b={b}, t={t}/T={T}, u={u}/U={U}",
                    )

def test_basic_backward(self):
rnnt_loss = RNNTLoss()
Expand Down

0 comments on commit 326e34d

Please sign in to comment.