diff --git a/include/detail/ctc_helper.h b/include/detail/ctc_helper.h index f1a94f6..29af220 100644 --- a/include/detail/ctc_helper.h +++ b/include/detail/ctc_helper.h @@ -44,13 +44,6 @@ template struct exponential { HOSTDEVICE Res operator()(const Arg& x) const {return std::exp(x);} }; -template struct clipped_exponential { - HOSTDEVICE Res operator()(const Arg& x) const { - Arg min_Arg = std::numeric_limits::min(); - return maximum()(min_Arg, std::exp(x)); - } -}; - template struct log_plus { typedef Res result_type; diff --git a/include/detail/gpu_ctc.h b/include/detail/gpu_ctc.h index 1ead9b7..2149d99 100644 --- a/include/detail/gpu_ctc.h +++ b/include/detail/gpu_ctc.h @@ -392,9 +392,12 @@ GpuCTC::compute_probs(const ProbT* const activations) { // Kernel launch to calculate probabilities compute_probs_kernel<<>> - (ctc_helper::clipped_exponential(), probs_, + (ctc_helper::exponential(), probs_, denoms_, out_dim_, num_elements); + truncate_probs_kernel<<>> + (probs_, num_elements); + return CTC_STATUS_SUCCESS; } diff --git a/include/detail/gpu_ctc_kernels.h b/include/detail/gpu_ctc_kernels.h index cf6dba9..07412d0 100644 --- a/include/detail/gpu_ctc_kernels.h +++ b/include/detail/gpu_ctc_kernels.h @@ -88,8 +88,8 @@ template __global__ void compute_alpha_kernel (const ProbT* probs, const int *label_sizes, const int *utt_length, const int *repeats_in_labels, - const int *labels_without_blanks, const int *label_offsets, - int *labels_with_blanks, ProbT *alphas, + const int *labels_without_blanks, const int *label_offsets, + int *labels_with_blanks, ProbT *alphas, ProbT* nll_forward, int stride, int out_dim, int S_memoffset, int T_memoffset, int blank_label) { @@ -469,6 +469,23 @@ __global__ void compute_probs_kernel(Op f, ProbT* probs, } } +template +__global__ void truncate_probs_kernel(ProbT* probs, int count) { + + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + ProbT min_T = numeric_limits::min(); +#pragma unroll + for(int i = 0; i < VT; i++) { + if (idx < count) { + if (min_T > probs[idx]) { + probs[idx] = min_T; + } + } + idx += stride; + } +} + template __global__ void prepare_stable_SM_kernel(Op f, ProbT* probs, const ProbT* const col_max,