Skip to content

Commit

Permalink
Merge pull request #4 from pkuyym/fix-gpu-numerical
Browse files Browse the repository at this point in the history
Fix numerical problem for gpu version.
  • Loading branch information
gangliao authored Aug 1, 2017
2 parents 0a60ee5 + 10ab30b commit 430e949
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
7 changes: 7 additions & 0 deletions include/detail/ctc_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ template <typename Arg, typename Res = Arg> struct exponential {
HOSTDEVICE Res operator()(const Arg& x) const {return std::exp(x);}
};

template <typename Arg, typename Res = Arg> struct clipped_exponential {
HOSTDEVICE Res operator()(const Arg& x) const {
Arg min_Arg = std::numeric_limits<Arg>::min();
return maximum<Res>()(min_Arg, std::exp(x));
}
};

template<typename Arg1, typename Arg2 = Arg1, typename Res=Arg1>
struct log_plus {
typedef Res result_type;
Expand Down
2 changes: 1 addition & 1 deletion include/detail/gpu_ctc.h
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ GpuCTC<ProbT>::compute_probs(const ProbT* const activations) {

// Kernel launch to calculate probabilities
compute_probs_kernel<ProbT, VT><<<grid_size, NT, 0, stream_>>>
(ctc_helper::exponential<ProbT>(), probs_,
(ctc_helper::clipped_exponential<ProbT>(), probs_,
denoms_, out_dim_, num_elements);

return CTC_STATUS_SUCCESS;
Expand Down

0 comments on commit 430e949

Please sign in to comment.