remove fused_log_softmax option
Caroline Chen committed Jul 1, 2021
1 parent 27690ec commit 8944d59
Showing 13 changed files with 32 additions and 136 deletions.
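The net effect for callers: both the functional and module APIs lose the `fused_log_softmax` keyword, and inputs are always treated as raw logits, with log-softmax fused into the loss kernels. A minimal usage sketch after this change (shapes, dtypes, and the toy values below are illustrative assumptions, not taken from the diff):

```python
import torch
from torchaudio.prototype.rnnt_loss import RNNTLoss

# Assumed toy shapes: batch=2, max source length T=10, max target length
# U=4 (3 target tokens + 1), vocabulary size D=5, blank = last index.
logits = torch.rand(2, 10, 4, 5, requires_grad=True)  # raw logits, not log-probs
targets = torch.randint(0, 4, (2, 3), dtype=torch.int32)
logit_lengths = torch.tensor([10, 8], dtype=torch.int32)
target_lengths = torch.tensor([3, 2], dtype=torch.int32)

# The constructor now takes only blank, clamp, and reduction; passing
# fused_log_softmax would raise a TypeError after this commit.
loss_fn = RNNTLoss(blank=-1, clamp=-1, reduction="mean")
loss_fn(logits, targets, logit_lengths, target_lengths).backward()
```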
1 change: 0 additions & 1 deletion test/torchaudio_unittest/rnnt/autograd_impl.py
@@ -75,7 +75,6 @@ def test_rnnt_loss_gradcheck(self, data_func):
             data["target_lengths"],  # target_lengths
             data["blank"],  # blank
             -1,  # clamp
-            True,  # fused_log_softmax
         )
 
         self.assert_grad(rnnt_loss, inputs, enable_all_grad=False)
17 changes: 0 additions & 17 deletions test/torchaudio_unittest/rnnt/rnnt_loss_impl.py
@@ -4,7 +4,6 @@
 from .utils import (
     compute_with_numpy_transducer,
     compute_with_pytorch_transducer,
-    get_B1_T10_U3_D4_data,
     get_data_basic,
     get_numpy_data_B1_T2_U3_D5,
     get_numpy_data_B2_T4_U3_D3,
@@ -95,19 +94,3 @@ def test_costs_and_gradients_random_data_with_numpy_fp32(self):
         self._test_costs_and_gradients(
             data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
         )
-
-    def test_rnnt_nonfused_log_softmax(self):
-        for random in [False, True]:
-            data = get_B1_T10_U3_D4_data(
-                random=random,
-            )
-            data = numpy_to_torch(
-                data=data, device=self.device, requires_grad=True
-            )
-            data["fused_log_softmax"] = False
-            ref_costs, ref_gradients = compute_with_numpy_transducer(
-                data=data
-            )
-            self._test_costs_and_gradients(
-                data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
-            )
1 change: 0 additions & 1 deletion test/torchaudio_unittest/rnnt/utils.py
@@ -29,7 +29,6 @@ def compute_with_numpy_transducer(data):
 def compute_with_pytorch_transducer(data):
     costs = RNNTLoss(
         blank=data["blank"],
-        fused_log_softmax=data.get("fused_log_softmax", True),
         reduction="none",
     )(
         logits=data["logits"],
24 changes: 5 additions & 19 deletions torchaudio/csrc/rnnt/autograd.cpp
@@ -13,17 +13,10 @@ class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
       const torch::Tensor& logit_lengths,
       const torch::Tensor& target_lengths,
       int64_t blank,
-      double clamp,
-      bool fused_log_softmax = true) {
+      double clamp) {
     torch::Tensor undef;
-    auto result = rnnt_loss(
-        logits,
-        targets,
-        logit_lengths,
-        target_lengths,
-        blank,
-        clamp,
-        fused_log_softmax);
+    auto result =
+        rnnt_loss(logits, targets, logit_lengths, target_lengths, blank, clamp);
     auto costs = std::get<0>(result);
     auto grads = std::get<1>(result).value_or(undef);
     ctx->save_for_backward({grads});
@@ -48,17 +41,10 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
     const torch::Tensor& logit_lengths,
     const torch::Tensor& target_lengths,
     int64_t blank,
-    double clamp,
-    bool fused_log_softmax = true) {
+    double clamp) {
   at::AutoDispatchBelowADInplaceOrView guard;
   auto results = RNNTLossFunction::apply(
-      logits,
-      targets,
-      logit_lengths,
-      target_lengths,
-      blank,
-      clamp,
-      fused_log_softmax);
+      logits, targets, logit_lengths, target_lengths, blank, clamp);
   return std::make_tuple(results[0], results[1]);
 }
 
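For readers less used to the C++ autograd API, here is a rough Python rendering of the pattern above: forward calls the dispatched op, which returns both costs and per-element gradients, and the gradients are stashed for backward. The backward body is not part of this diff, so scaling the saved gradients by the incoming gradient is an inferred convention, and the sketch returns only the costs for simplicity:

```python
import torch

class RNNTLossFunction(torch.autograd.Function):
    # Sketch only; the real binding stays in C++ (torchaudio/csrc/rnnt/autograd.cpp).
    @staticmethod
    def forward(ctx, logits, targets, logit_lengths, target_lengths, blank, clamp):
        # The op returns (costs, grads); grads are saved so backward is cheap.
        costs, grads = torch.ops.torchaudio.rnnt_loss(
            logits, targets, logit_lengths, target_lengths, blank, clamp)
        ctx.save_for_backward(grads)
        return costs

    @staticmethod
    def backward(ctx, grad_output):
        (grads,) = ctx.saved_tensors
        # One cost per batch element; broadcast over the (T, U, D) gradient.
        grad_logits = grads * grad_output.view(-1, 1, 1, 1)
        return grad_logits, None, None, None, None, None
```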
15 changes: 3 additions & 12 deletions torchaudio/csrc/rnnt/compute.cpp
@@ -7,19 +7,11 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
     const torch::Tensor& logit_lengths,
     const torch::Tensor& target_lengths,
     int64_t blank,
-    double clamp,
-    bool fused_log_softmax = true) {
+    double clamp) {
   static auto op = torch::Dispatcher::singleton()
                        .findSchemaOrThrow("torchaudio::rnnt_loss", "")
                        .typed<decltype(rnnt_loss)>();
-  return op.call(
-      logits,
-      targets,
-      logit_lengths,
-      target_lengths,
-      blank,
-      clamp,
-      fused_log_softmax);
+  return op.call(logits, targets, logit_lengths, target_lengths, blank, clamp);
 }
 
 TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
@@ -29,6 +21,5 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
       "Tensor logit_lengths,"
       "Tensor target_lengths,"
       "int blank,"
-      "float clamp,"
-      "bool fused_log_softmax=True) -> (Tensor, Tensor?)");
+      "float clamp) -> (Tensor, Tensor?)");
 }
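The trimmed schema can be exercised directly through the dispatcher. A hedged sketch — toy shapes are assumptions, and `torch.ops.torchaudio.rnnt_loss` is only available once the torchaudio extension is loaded:

```python
import torch
import torchaudio  # noqa: F401 -- loads the extension registering torchaudio::rnnt_loss

# Assumed toy shapes: (B=1, T=2, U=3, D=5); blank chosen as the last index.
logits = torch.rand(1, 2, 3, 5)
targets = torch.randint(0, 4, (1, 2), dtype=torch.int32)
logit_lengths = torch.tensor([2], dtype=torch.int32)
target_lengths = torch.tensor([2], dtype=torch.int32)

# Positional arguments now end at clamp, matching the updated schema
# "... int blank, float clamp) -> (Tensor, Tensor?)".
costs, grads = torch.ops.torchaudio.rnnt_loss(
    logits, targets, logit_lengths, target_lengths, 4, -1.0)
```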
3 changes: 1 addition & 2 deletions torchaudio/csrc/rnnt/compute.h
@@ -8,5 +8,4 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
     const torch::Tensor& logit_lengths,
     const torch::Tensor& target_lengths,
     int64_t blank,
-    double clamp,
-    bool fused_log_softmax);
+    double clamp);
4 changes: 1 addition & 3 deletions torchaudio/csrc/rnnt/cpu/compute.cpp
@@ -12,8 +12,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
     const torch::Tensor& logit_lengths,
     const torch::Tensor& target_lengths,
     int64_t blank,
-    double clamp,
-    bool fused_log_softmax = true) {
+    double clamp) {
   TORCH_CHECK(
       logits.device().type() == targets.device().type(),
       "logits and targets must be on the same device");
@@ -81,7 +80,6 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
   options.numTargets_ = logits.size(3);
   options.blank_ = blank;
   options.clamp_ = clamp;
-  options.fusedLogSmax_ = fused_log_softmax;
 
   CHECK_EQ(logits.device().type(), torch::DeviceType::CPU);
   options.device_ = CPU;
4 changes: 1 addition & 3 deletions torchaudio/csrc/rnnt/gpu/compute.cu
@@ -13,8 +13,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
     const torch::Tensor& logit_lengths,
     const torch::Tensor& target_lengths,
     int64_t blank,
-    double clamp,
-    bool fused_log_softmax = true) {
+    double clamp) {
   TORCH_CHECK(
       logits.device().type() == targets.device().type(),
       "logits and targets must be on the same device");
@@ -82,7 +81,6 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
   options.numTargets_ = logits.size(3);
   options.blank_ = blank;
   options.clamp_ = clamp;
-  options.fusedLogSmax_ = fused_log_softmax;
 
   CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
   options.stream_ = at::cuda::getCurrentCUDAStream();
19 changes: 3 additions & 16 deletions torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh
@@ -23,8 +23,7 @@ __global__ void ComputeLogProbs(
     const int* tgtLengths,
     const CAST_DTYPE* denominators,
     CAST_DTYPE* logProbs,
-    int H = 1,
-    bool fusedLogSmax = true) {
+    int H = 1) {
   const int& maxT = maxSrcLen;
   const int& maxU = maxTgtLen;
   const int& D = numTargets;
@@ -49,22 +48,12 @@ __global__ void ComputeLogProbs(
     logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] =
         CAST_DTYPE(logits[idx * D + blank]) - denominators[idx];
 
-    if (!fusedLogSmax) {
-      logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] =
-          CAST_DTYPE(logits[idx * D + blank]);
-    }
-
     if (u < U - 1) {
       // emit: log_prob(b, t, u).emit() = logits(b, t, u, tgt[u]) - denom(b, t,
       // u).
       int target = targets[Indexer2D(maxU - 1)(bTgt, u)];
       logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] =
           CAST_DTYPE(logits[idx * D + target]) - denominators[idx];
-
-      if (!fusedLogSmax) {
-        logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] =
-            CAST_DTYPE(logits[idx * D + target]);
-      }
     }
   }
 
@@ -330,8 +319,7 @@ __global__ void ComputeGradients(
     const CAST_DTYPE* alphas,
     const CAST_DTYPE* betas,
     DTYPE* gradients,
-    int H = 1,
-    bool fusedLogSmax = true) {
+    int H = 1) {
   const int bTgt = blockIdx.z; // 0 <= b < B
   const int t = blockIdx.x * blockDim.x + threadIdx.x;
   const int u = blockIdx.y;
@@ -353,8 +341,7 @@ __global__ void ComputeGradients(
       alphas,
       betas,
       gradients,
-      H,
-      fusedLogSmax);
+      H);
 }
 
 // This is a __global__ wrapper around ComputeAlphas
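With the `!fusedLogSmax` overwrites gone, `ComputeLogProbs` always subtracts the log-softmax denominator. A NumPy rendering of the retained per-(t, u) computation for a single sequence — the skip/emit formulas follow the kernel comments, while packing skip at index 0 and emit at index 1 and computing the denominator inline are my simplifications:

```python
import numpy as np

def compute_log_probs(logits, targets, blank):
    # logits: (T, U, D) raw scores for one sequence; targets: (U - 1,) token ids.
    T, U, D = logits.shape
    # Denominator: log-sum-exp over the vocabulary, as LogSumExp2D provides it.
    denominators = np.logaddexp.reduce(logits, axis=-1)
    log_probs = np.full((T, U, 2), -np.inf)  # [:, :, 0]=skip, [:, :, 1]=emit
    for t in range(T):
        for u in range(U):
            # skip: log_prob(t, u).skip = logits(t, u, blank) - denom(t, u)
            log_probs[t, u, 0] = logits[t, u, blank] - denominators[t, u]
            if u < U - 1:
                # emit: log_prob(t, u).emit = logits(t, u, tgt[u]) - denom(t, u)
                log_probs[t, u, 1] = logits[t, u, targets[u]] - denominators[t, u]
    return log_probs
```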
8 changes: 2 additions & 6 deletions torchaudio/csrc/rnnt/gpu/gpu_transducer.h
@@ -102,8 +102,6 @@ status_t Compute(
   const int& blank = options.blank_;
   const CAST_DTYPE clamp = options.clamp_;
 
-  const bool& fusedLogSmax = options.fusedLogSmax_;
-
   { // compute denominators.
     status_t status = LogSumExp2D<DTYPE, CAST_DTYPE>(
         /*stream=*/stream,
@@ -134,8 +132,7 @@ status_t Compute(
         /*tgtLengths=*/tgtLengths,
         /*denominators=*/workspace.GetPointerToDenominators(),
         /*log_probs=*/workspace.GetPointerToLogProbs(),
-        H,
-        fusedLogSmax);
+        H);
 
     if (cudaGetLastError() != cudaSuccess) {
       return COMPUTE_LOG_PROBS_FAILED;
@@ -200,8 +197,7 @@ status_t Compute(
         /*alphas=*/workspace.GetPointerToAlphas(),
         /*betas=*/workspace.GetPointerToBetas(),
         /*gradients=*/gradients,
-        H,
-        fusedLogSmax);
+        H);
     if (cudaGetLastError() != cudaSuccess) {
      return COMPUTE_GRADIENTS_FAILED;
    }
53 changes: 15 additions & 38 deletions torchaudio/csrc/rnnt/gpu/kernels.h
@@ -26,8 +26,7 @@ HOST_AND_DEVICE void ComputeGradientsElement(
     const CAST_DTYPE* alphas,
     const CAST_DTYPE* betas,
     DTYPE* gradients,
-    int H = 1,
-    bool fusedLogSmax = true) {
+    int H = 1) {
   const int& maxT = maxSrcLen;
   const int& maxU = maxTgtLen;
   const int& D = numTargets;
@@ -79,44 +78,22 @@ HOST_AND_DEVICE void ComputeGradientsElement(
       int b_t_u_d = idx_b_t_u * D + d;
       CAST_DTYPE g = CAST_DTYPE(logits[b_t_u_d]) + c;
 
-      if (fusedLogSmax) {
-        if (d == blank && t == T - 1 && u == U - 1) { // last blank transition.
-          gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]) - std::exp(g);
-        } else if (t < T - 1 && d == blank) {
-          gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
-          if (idx_b_tp1_u != -1) {
-            gradients[b_t_u_d] =
-                gradients[b_t_u_d] - std::exp(g + betas[idx_b_tp1_u]);
-          }
-        } else if (u < U - 1 && d == targets[idxr2(bTgt, u)]) {
-          gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
-          if (idx_b_t_up1 != -1) {
-            gradients[b_t_u_d] =
-                gradients[b_t_u_d] - std::exp(g + betas[idx_b_t_up1]);
-          }
-        } else {
-          gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
-        }
-      } else { // Non fused log softmax case
-        CAST_DTYPE g = cost + CAST_DTYPE(logits[b_t_u_d]);
-        if (d == blank && t == T - 1 && u == U - 1) {
-          gradients[b_t_u_d] = g + alphas[idx_b_t_u];
-        } else if (t < T - 1 && d == blank) {
-          if (idx_b_tp1_u != -1) {
-            gradients[b_t_u_d] = g + alphas[idx_b_t_u] + betas[idx_b_tp1_u];
-          } else {
-            gradients[b_t_u_d] = g + CAST_DTYPE(-INFINITY);
-          }
-        } else if (u < U - 1 && d == targets[idxr2(bTgt, u)]) {
-          if (idx_b_t_up1 != -1) {
-            gradients[b_t_u_d] = g + alphas[idx_b_t_u] + betas[idx_b_t_up1];
-          } else {
-            gradients[b_t_u_d] = g + CAST_DTYPE(-INFINITY);
-          }
-        } else {
-          gradients[b_t_u_d] = g + CAST_DTYPE(-INFINITY);
-        }
-        gradients[b_t_u_d] = -std::exp(gradients[b_t_u_d]);
-      }
+      if (d == blank && t == T - 1 && u == U - 1) { // last blank transition.
+        gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]) - std::exp(g);
+      } else if (t < T - 1 && d == blank) {
+        gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
+        if (idx_b_tp1_u != -1) {
+          gradients[b_t_u_d] =
+              gradients[b_t_u_d] - std::exp(g + betas[idx_b_tp1_u]);
+        }
+      } else if (u < U - 1 && d == targets[idxr2(bTgt, u)]) {
+        gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
+        if (idx_b_t_up1 != -1) {
+          gradients[b_t_u_d] =
+              gradients[b_t_u_d] - std::exp(g + betas[idx_b_t_up1]);
+        }
+      } else {
+        gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
+      }
 
       if (clamp > 0) {
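The surviving branch is exactly the fused log-softmax gradient: every class gets the softmax term exp(g + beta(t, u)), and the transition actually taken (blank for a skip, the target token for an emit) subtracts the probability mass flowing to the next lattice node. A NumPy sketch for one (t, u) cell; I assume `cost = -betas[0, 0]` and `c = alphas[t, u] + cost - denominators[t, u]`, which are defined above the excerpt and not shown in this diff:

```python
import numpy as np

def gradients_at(logits, denominators, alphas, betas, targets, t, u, blank):
    # logits: (T, U, D); alphas/betas/denominators: (T, U); single sequence.
    T, U, D = logits.shape
    cost = -betas[0, 0]                            # assumed: total neg. log-likelihood
    c = alphas[t, u] + cost - denominators[t, u]   # assumed precomputed constant
    grads = np.empty(D)
    for d in range(D):
        g = logits[t, u, d] + c
        if d == blank and t == T - 1 and u == U - 1:   # last blank transition
            grads[d] = np.exp(g + betas[t, u]) - np.exp(g)
        elif d == blank and t < T - 1:                 # horizontal (skip) move
            grads[d] = np.exp(g + betas[t, u]) - np.exp(g + betas[t + 1, u])
        elif u < U - 1 and d == targets[u]:            # vertical (emit) move
            grads[d] = np.exp(g + betas[t, u]) - np.exp(g + betas[t, u + 1])
        else:                                          # softmax-only term
            grads[d] = np.exp(g + betas[t, u])
    return grads
```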
9 changes: 1 addition & 8 deletions torchaudio/csrc/rnnt/options.h
@@ -42,12 +42,6 @@ typedef struct Options {
   // num_targets = D.
   int numTargets_;
 
-  // if set to true, inputs are logits and gradients are
-  // fused with logsoftmax gradients.
-  // if set to false, log_softmax is computed outside of loss
-  // True by default
-  bool fusedLogSmax_;
-
   Options()
       : device_(UNDEFINED),
         numThreads_(0),
@@ -58,8 +52,7 @@ typedef struct Options {
         nHypos_(1),
         maxSrcLen_(0),
         maxTgtLen_(0),
-        numTargets_(0),
-        fusedLogSmax_(true) {}
+        numTargets_(0) {}
 
   int BU() const {
     return batchSize_ * maxTgtLen_ * nHypos_;
10 changes: 0 additions & 10 deletions torchaudio/prototype/rnnt_loss.py
@@ -14,7 +14,6 @@ def rnnt_loss(
     target_lengths: Tensor,
     blank: int = -1,
     clamp: float = -1,
-    fused_log_softmax: bool = True,
     reduction: str = "mean",
 ):
     """Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks*
@@ -31,7 +30,6 @@ def rnnt_loss(
         target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
         blank (int, opt): blank label (Default: ``-1``)
         clamp (float): clamp for gradients (Default: ``-1``)
-        fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
         reduction (string, optional): Specifies the reduction to apply to the output:
             ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
 
@@ -42,9 +40,6 @@ def rnnt_loss(
     if reduction not in ['none', 'mean', 'sum']:
         raise ValueError("reduction should be one of 'none', 'mean', or 'sum'")
 
-    if not fused_log_softmax:
-        logits = torch.nn.functional.log_softmax(logits, dim=-1)
-
     if blank < 0:  # reinterpret blank index if blank < 0.
         blank = logits.shape[-1] + blank
 
@@ -55,7 +50,6 @@ def rnnt_loss(
         target_lengths=target_lengths,
         blank=blank,
         clamp=clamp,
-        fused_log_softmax=fused_log_softmax
     )
 
     if reduction == 'mean':
@@ -77,7 +71,6 @@ class RNNTLoss(torch.nn.Module):
     Args:
         blank (int, opt): blank label (Default: ``-1``)
         clamp (float): clamp for gradients (Default: ``-1``)
-        fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
         reduction (string, optional): Specifies the reduction to apply to the output:
             ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
     """
@@ -86,13 +79,11 @@ def __init__(
         self,
         blank: int = -1,
         clamp: float = -1.,
-        fused_log_softmax: bool = True,
         reduction: str = "mean",
     ):
         super().__init__()
         self.blank = blank
         self.clamp = clamp
-        self.fused_log_softmax = fused_log_softmax
         self.reduction = reduction
 
     def forward(
@@ -120,6 +111,5 @@ def forward(
             target_lengths,
             self.blank,
             self.clamp,
-            self.fused_log_softmax,
             self.reduction
         )
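Migration note, stated as my reading of the removed lines rather than anything the commit documents: `fused_log_softmax=False` only applied `torch.nn.functional.log_softmax` in this wrapper and told the C++ kernels to skip the denominator, so the one supported path now is to hand the loss raw logits (toy shapes below are assumptions):

```python
import torch
from torchaudio.prototype.rnnt_loss import rnnt_loss

# Assumed toy shapes: (B=1, T=2, U=3, D=5).
logits = torch.rand(1, 2, 3, 5, requires_grad=True)
targets = torch.randint(0, 4, (1, 2), dtype=torch.int32)
logit_lengths = torch.tensor([2], dtype=torch.int32)
target_lengths = torch.tensor([2], dtype=torch.int32)

# Before: rnnt_loss(..., fused_log_softmax=False) with externally normalized
# inputs. After: drop the keyword; normalization is always fused internally.
loss = rnnt_loss(logits, targets, logit_lengths, target_lengths,
                 blank=-1, clamp=-1, reduction="mean")
```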
