Skip to content

Commit

Permalink
Merge pull request #5 from pkuyym/fix-gpu-numerical
Browse files Browse the repository at this point in the history
Add truncate kernel to avoid occuring 0 in probs_
  • Loading branch information
gangliao authored Aug 1, 2017
2 parents 430e949 + 054b67b commit 02b25dd
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 10 deletions.
7 changes: 0 additions & 7 deletions include/detail/ctc_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,6 @@ template <typename Arg, typename Res = Arg> struct exponential {
HOSTDEVICE Res operator()(const Arg& x) const {return std::exp(x);}
};

template <typename Arg, typename Res = Arg> struct clipped_exponential {
HOSTDEVICE Res operator()(const Arg& x) const {
Arg min_Arg = std::numeric_limits<Arg>::min();
return maximum<Res>()(min_Arg, std::exp(x));
}
};

template<typename Arg1, typename Arg2 = Arg1, typename Res=Arg1>
struct log_plus {
typedef Res result_type;
Expand Down
5 changes: 4 additions & 1 deletion include/detail/gpu_ctc.h
Original file line number Diff line number Diff line change
Expand Up @@ -392,9 +392,12 @@ GpuCTC<ProbT>::compute_probs(const ProbT* const activations) {

// Kernel launch to calculate probabilities
compute_probs_kernel<ProbT, VT><<<grid_size, NT, 0, stream_>>>
(ctc_helper::clipped_exponential<ProbT>(), probs_,
(ctc_helper::exponential<ProbT>(), probs_,
denoms_, out_dim_, num_elements);

truncate_probs_kernel<ProbT, VT><<<grid_size, NT, 0, stream_>>>
(probs_, num_elements);

return CTC_STATUS_SUCCESS;
}

Expand Down
21 changes: 19 additions & 2 deletions include/detail/gpu_ctc_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ template<typename ProbT, int NT, int VT>
__global__
void compute_alpha_kernel (const ProbT* probs, const int *label_sizes,
const int *utt_length, const int *repeats_in_labels,
const int *labels_without_blanks, const int *label_offsets,
int *labels_with_blanks, ProbT *alphas,
const int *labels_without_blanks, const int *label_offsets,
int *labels_with_blanks, ProbT *alphas,
ProbT* nll_forward, int stride, int out_dim,
int S_memoffset, int T_memoffset, int blank_label) {

Expand Down Expand Up @@ -469,6 +469,23 @@ __global__ void compute_probs_kernel(Op f, ProbT* probs,
}
}

template <typename ProbT, int VT = 1>
__global__ void truncate_probs_kernel(ProbT* probs, int count) {

int idx = blockDim.x * blockIdx.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
ProbT min_T = numeric_limits<ProbT>::min();
#pragma unroll
for(int i = 0; i < VT; i++) {
if (idx < count) {
if (min_T > probs[idx]) {
probs[idx] = min_T;
}
}
idx += stride;
}
}

template <typename ProbT, int VT = 1, typename Op>
__global__ void prepare_stable_SM_kernel(Op f, ProbT* probs,
const ProbT* const col_max,
Expand Down

0 comments on commit 02b25dd

Please sign in to comment.