Remove reuse_logits_for_grads option for RNNTLoss (#1610)
Caroline Chen authored Aug 3, 2021
1 parent 25ceee7 commit 16f3b2f
Showing 9 changed files with 20 additions and 53 deletions.
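
For readers tracking the API change: callers simply drop the keyword. A minimal before/after sketch, assuming the prototype module `torchaudio.prototype.rnnt_loss` (the file touched below) is available; the argument values are illustrative only.

```python
from torchaudio.prototype.rnnt_loss import RNNTLoss

# Before this commit (keyword removed by this change):
# loss_fn = RNNTLoss(blank=0, clamp=-1, fused_log_softmax=True, reuse_logits_for_grads=False)

# After this commit:
loss_fn = RNNTLoss(blank=0, clamp=-1, fused_log_softmax=True)
```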
3 changes: 1 addition & 2 deletions test/torchaudio_unittest/rnnt/autograd_impl.py
@@ -53,7 +53,7 @@ def test_RNNTLoss_gradcheck(self, data_func):
            data["logit_lengths"],
            data["target_lengths"],
        )
-       loss = RNNTLoss(blank=data["blank"], reuse_logits_for_grads=False)
+       loss = RNNTLoss(blank=data["blank"])

        self.assert_grad(loss, inputs, enable_all_grad=False)

@@ -72,7 +72,6 @@ def test_rnnt_loss_gradcheck(self, data_func):
            data["blank"],  # blank
            -1,  # clamp
            True,  # fused_log_softmax
-           False,  # reuse_logits_for_grads
        )

        self.assert_grad(rnnt_loss, inputs, enable_all_grad=False)
12 changes: 4 additions & 8 deletions test/torchaudio_unittest/rnnt/rnnt_loss_impl.py
@@ -17,14 +17,10 @@ def _test_costs_and_gradients(
        self, data, ref_costs, ref_gradients, atol=1e-6, rtol=1e-2
    ):
        logits_shape = data["logits"].shape
-       for reuse_logits_for_grads in [False, True]:
-           with self.subTest(reuse_logits_for_grads=reuse_logits_for_grads):
-               costs, gradients = compute_with_pytorch_transducer(
-                   data=data, reuse_logits_for_grads=reuse_logits_for_grads
-               )
-               self.assertEqual(costs, ref_costs, atol=atol, rtol=rtol)
-               self.assertEqual(logits_shape, gradients.shape)
-               self.assertEqual(gradients, ref_gradients, atol=atol, rtol=rtol)
+       costs, gradients = compute_with_pytorch_transducer(data=data)
+       self.assertEqual(costs, ref_costs, atol=atol, rtol=rtol)
+       self.assertEqual(logits_shape, gradients.shape)
+       self.assertEqual(gradients, ref_gradients, atol=atol, rtol=rtol)

    def test_basic_backward(self):
        rnnt_loss = RNNTLoss()
3 changes: 1 addition & 2 deletions test/torchaudio_unittest/rnnt/utils.py
@@ -23,11 +23,10 @@ def compute_with_numpy_transducer(data):
    return costs, gradients


-def compute_with_pytorch_transducer(data, reuse_logits_for_grads=False):
+def compute_with_pytorch_transducer(data):
    costs = RNNTLoss(
        blank=data["blank"],
        fused_log_softmax=data.get("fused_log_softmax", True),
-       reuse_logits_for_grads=reuse_logits_for_grads,
        reduction="none",
    )(
        logits=data["logits"],
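
With the flag gone, a test helper along these lines is all that is needed to recover per-utterance costs and gradients. This is a rough sketch of the simplified helper above, under the assumption that `data["logits"]` has `requires_grad=True` so the gradients land in `logits.grad`:

```python
import torch
from torchaudio.prototype.rnnt_loss import RNNTLoss

def compute_costs_and_grads(data):
    """Sketch: per-utterance costs plus the gradient w.r.t. the logits."""
    costs = RNNTLoss(
        blank=data["blank"],
        fused_log_softmax=data.get("fused_log_softmax", True),
        reduction="none",
    )(
        logits=data["logits"],
        targets=data["targets"],
        logit_lengths=data["logit_lengths"],
        target_lengths=data["target_lengths"],
    )
    # Backpropagate a vector of ones so logits.grad holds the per-sample gradients.
    costs.backward(torch.ones_like(costs))
    return costs.detach(), data["logits"].grad
```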
12 changes: 4 additions & 8 deletions torchaudio/csrc/rnnt/autograd.cpp
@@ -14,8 +14,7 @@ class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
      const torch::Tensor& target_lengths,
      int64_t blank,
      double clamp,
-     bool fused_log_softmax = true,
-     bool reuse_logits_for_grads = true) {
+     bool fused_log_softmax = true) {
    torch::Tensor undef;
    auto result = rnnt_loss(
        logits,
@@ -24,8 +23,7 @@ class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
        target_lengths,
        blank,
        clamp,
-       fused_log_softmax,
-       reuse_logits_for_grads);
+       fused_log_softmax);
    auto costs = std::get<0>(result);
    auto grads = std::get<1>(result).value_or(undef);
    ctx->save_for_backward({grads});
@@ -51,8 +49,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
    const torch::Tensor& target_lengths,
    int64_t blank,
    double clamp,
-   bool fused_log_softmax = true,
-   bool reuse_logits_for_grads = true) {
+   bool fused_log_softmax = true) {
  at::AutoDispatchBelowADInplaceOrView guard;
  auto results = RNNTLossFunction::apply(
      logits,
@@ -61,8 +58,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
      target_lengths,
      blank,
      clamp,
-     fused_log_softmax,
-     reuse_logits_for_grads);
+     fused_log_softmax);
  return std::make_tuple(results[0], results[1]);
}

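
The wrapper above follows a precompute-and-save pattern: the kernel returns both costs and ready-made gradients, forward stashes the gradients, and backward only rescales them by the incoming gradient. A generic Python sketch of that pattern, with a stand-in kernel rather than the real binding:

```python
import torch

class PrecomputedGradLoss(torch.autograd.Function):
    """Toy analogue of RNNTLossFunction: gradients are computed in forward."""

    @staticmethod
    def forward(ctx, logits):
        # Stand-in "kernel": per-utterance cost and a matching gradient tensor.
        costs = logits.sum(dim=(1, 2, 3))
        grads = torch.ones_like(logits)
        ctx.save_for_backward(grads)
        return costs

    @staticmethod
    def backward(ctx, grad_output):
        (grads,) = ctx.saved_tensors
        # Scale the saved gradients by the incoming per-utterance gradient,
        # broadcasting over the (time, target, class) dimensions.
        return grads * grad_output.view(-1, 1, 1, 1)

x = torch.randn(2, 10, 6, 20, requires_grad=True)
PrecomputedGradLoss.apply(x).sum().backward()
```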
9 changes: 3 additions & 6 deletions torchaudio/csrc/rnnt/compute.cpp
@@ -8,8 +8,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
    const torch::Tensor& target_lengths,
    int64_t blank,
    double clamp,
-   bool fused_log_softmax = true,
-   bool reuse_logits_for_grads = true) {
+   bool fused_log_softmax = true) {
  static auto op = torch::Dispatcher::singleton()
                       .findSchemaOrThrow("torchaudio::rnnt_loss", "")
                       .typed<decltype(rnnt_loss)>();
@@ -20,8 +19,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
      target_lengths,
      blank,
      clamp,
-     fused_log_softmax,
-     reuse_logits_for_grads);
+     fused_log_softmax);
}

TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
@@ -32,6 +30,5 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
      "Tensor target_lengths,"
      "int blank,"
      "float clamp,"
-     "bool fused_log_softmax=True,"
-     "bool reuse_logits_for_grads=True) -> (Tensor, Tensor?)");
+     "bool fused_log_softmax=True) -> (Tensor, Tensor?)");
}
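
With the schema trimmed, a direct call to the registered op takes seven arguments and returns `(costs, grads?)`. A hedged sketch, assuming a torchaudio build that registers the prototype `torchaudio::rnnt_loss` op; the shapes and the blank index are illustrative:

```python
import torch
import torchaudio  # noqa: F401  (loads the extension that registers torchaudio::rnnt_loss)

logits = torch.randn(1, 4, 3, 5, requires_grad=True)
targets = torch.tensor([[1, 2]], dtype=torch.int32)
logit_lengths = torch.tensor([4], dtype=torch.int32)
target_lengths = torch.tensor([2], dtype=torch.int32)

# Positional order per the schema above:
# logits, targets, logit_lengths, target_lengths, blank, clamp, fused_log_softmax
costs, grads = torch.ops.torchaudio.rnnt_loss(
    logits, targets, logit_lengths, target_lengths, 0, -1.0, True
)
```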
3 changes: 1 addition & 2 deletions torchaudio/csrc/rnnt/compute.h
@@ -9,5 +9,4 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
    const torch::Tensor& target_lengths,
    int64_t blank,
    double clamp,
-   bool fused_log_softmax,
-   bool reuse_logits_for_grads);
+   bool fused_log_softmax);
9 changes: 2 additions & 7 deletions torchaudio/csrc/rnnt/cpu/compute.cpp
@@ -13,8 +13,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
    const torch::Tensor& target_lengths,
    int64_t blank,
    double clamp,
-   bool fused_log_softmax = true,
-   bool reuse_logits_for_grads = true) {
+   bool fused_log_softmax = true) {
  TORCH_CHECK(
      logits.device().type() == targets.device().type(),
      "logits and targets must be on the same device");
@@ -92,11 +91,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
  c10::optional<torch::Tensor> gradients = c10::nullopt;
  if (logits.requires_grad()) {
-   if (reuse_logits_for_grads) {
-     gradients = logits;
-   } else {
-     gradients = torch::zeros_like(logits);
-   }
+   gradients = torch::zeros_like(logits);
  }

  torch::Tensor int_workspace = torch::empty(
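
The deleted branch was a memory optimization: with `reuse_logits_for_grads=True` the gradients were written into the logits buffer itself, saving one `(B, T, U, D)`-sized allocation but clobbering the original logits (which is why the unfused log_softmax path had to disable it). After this change a fresh zero buffer is always allocated. A toy illustration of the aliasing difference, with arbitrary shapes:

```python
import torch

logits = torch.randn(2, 10, 6, 20)

# Old reuse path: the gradient tensor aliased the logits storage.
reused = logits.detach()
assert reused.data_ptr() == logits.data_ptr()  # no new allocation, but writes would overwrite logits

# New path: a separate, zero-initialized buffer of the same shape and dtype.
fresh = torch.zeros_like(logits)
assert fresh.data_ptr() != logits.data_ptr()
```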
9 changes: 2 additions & 7 deletions torchaudio/csrc/rnnt/gpu/compute.cu
@@ -14,8 +14,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
    const torch::Tensor& target_lengths,
    int64_t blank,
    double clamp,
-   bool fused_log_softmax = true,
-   bool reuse_logits_for_grads = true) {
+   bool fused_log_softmax = true) {
  TORCH_CHECK(
      logits.device().type() == targets.device().type(),
      "logits and targets must be on the same device");
@@ -95,11 +94,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
  c10::optional<torch::Tensor> gradients = c10::nullopt;
  if (logits.requires_grad()) {
-   if (reuse_logits_for_grads) {
-     gradients = logits;
-   } else {
-     gradients = torch::zeros_like(logits);
-   }
+   gradients = torch::zeros_like(logits);
  }

  torch::Tensor int_workspace = torch::empty(
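
Since the CPU and CUDA kernels now share the same, simpler gradient path, a quick parity check is straightforward. A sketch under the assumption that a CUDA build of the prototype loss is available; the tolerances mirror the ones used in the tests above:

```python
import torch
from torchaudio.prototype.rnnt_loss import RNNTLoss

torch.manual_seed(0)
base_logits = torch.randn(2, 10, 6, 20)
targets = torch.randint(1, 20, (2, 5), dtype=torch.int32)
logit_lengths = torch.full((2,), 10, dtype=torch.int32)
target_lengths = torch.full((2,), 5, dtype=torch.int32)

def costs_and_grads(device):
    logits = base_logits.clone().to(device).requires_grad_(True)
    costs = RNNTLoss(blank=0, reduction="none")(
        logits=logits,
        targets=targets.to(device),
        logit_lengths=logit_lengths.to(device),
        target_lengths=target_lengths.to(device),
    )
    costs.sum().backward()
    return costs.detach().cpu(), logits.grad.detach().cpu()

if torch.cuda.is_available():
    cpu_costs, cpu_grads = costs_and_grads("cpu")
    gpu_costs, gpu_grads = costs_and_grads("cuda")
    torch.testing.assert_close(gpu_costs, cpu_costs, atol=1e-6, rtol=1e-2)
    torch.testing.assert_close(gpu_grads, cpu_grads, atol=1e-6, rtol=1e-2)
```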
13 changes: 2 additions & 11 deletions torchaudio/prototype/rnnt_loss.py
@@ -15,7 +15,6 @@ def rnnt_loss(
    blank: int = -1,
    clamp: float = -1,
    fused_log_softmax: bool = True,
-   reuse_logits_for_grads: bool = True,
    reduction: str = "mean",
):
    """Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks*
@@ -33,7 +32,6 @@ def rnnt_loss(
        blank (int, opt): blank label (Default: ``-1``)
        clamp (float): clamp for gradients (Default: ``-1``)
        fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
-       reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``True``)
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
@@ -46,9 +44,6 @@

    if not fused_log_softmax:
        logits = torch.nn.functional.log_softmax(logits, dim=-1)
-       reuse_logits_for_grads = (
-           False  # softmax needs the original logits value
-       )

    if blank < 0:  # reinterpret blank index if blank < 0.
        blank = logits.shape[-1] + blank
@@ -60,8 +55,8 @@
        target_lengths=target_lengths,
        blank=blank,
        clamp=clamp,
-       fused_log_softmax=fused_log_softmax,
-       reuse_logits_for_grads=reuse_logits_for_grads,)
+       fused_log_softmax=fused_log_softmax
+   )

    if reduction == 'mean':
        return costs.mean()
@@ -83,7 +78,6 @@ class RNNTLoss(torch.nn.Module):
        blank (int, opt): blank label (Default: ``-1``)
        clamp (float): clamp for gradients (Default: ``-1``)
        fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
-       reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``True``)
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
    """
@@ -93,14 +87,12 @@ def __init__(
        blank: int = -1,
        clamp: float = -1.,
        fused_log_softmax: bool = True,
-       reuse_logits_for_grads: bool = True,
        reduction: str = "mean",
    ):
        super().__init__()
        self.blank = blank
        self.clamp = clamp
        self.fused_log_softmax = fused_log_softmax
-       self.reuse_logits_for_grads = reuse_logits_for_grads
        self.reduction = reduction

    def forward(
@@ -129,6 +121,5 @@ def forward(
            self.blank,
            self.clamp,
            self.fused_log_softmax,
-           self.reuse_logits_for_grads,
            self.reduction
        )
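
Putting it together, the prototype API after this commit exposes only `blank`, `clamp`, `fused_log_softmax`, and `reduction`. A sketch of both entry points with illustrative shapes and data (logits of shape `(batch, time, target_len + 1, num_classes)`):

```python
import torch
from torchaudio.prototype.rnnt_loss import RNNTLoss, rnnt_loss

batch, time, target_len, num_classes = 2, 10, 5, 20
logits = torch.randn(batch, time, target_len + 1, num_classes, requires_grad=True)
targets = torch.randint(1, num_classes, (batch, target_len), dtype=torch.int32)
logit_lengths = torch.full((batch,), time, dtype=torch.int32)
target_lengths = torch.full((batch,), target_len, dtype=torch.int32)

# Functional form: per-utterance costs.
costs = rnnt_loss(logits, targets, logit_lengths, target_lengths,
                  blank=0, reduction="none")

# Module form with the default "mean" reduction.
loss = RNNTLoss(blank=0)(logits=logits, targets=targets,
                         logit_lengths=logit_lengths, target_lengths=target_lengths)
loss.backward()
```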
