Skip to content

Commit

Permalink
Fix mxnet ctc_loss bug (apache#11834)
Browse files Browse the repository at this point in the history
* fix ctc_loss GPU bug

* add blank_label parameter for CTCLoss

* Revert "add blank_label parameter for CTCLoss"

This reverts commit aab11f7.
  • Loading branch information
Mingkun Huang authored and szha committed Jul 27, 2018
1 parent 69fe157 commit 9e251d3
Showing 1 changed file with 1 addition and 6 deletions.
7 changes: 1 addition & 6 deletions src/operator/contrib/ctc_include/detail/gpu_ctc.h
Original file line number Diff line number Diff line change
Expand Up @@ -411,12 +411,7 @@ GpuCTC<ProbT>::compute_log_probs(const ProbT* const activations) {
denoms_, out_dim_, num_elements);

// compute denominators for softmax
denoms_handle = reduce_with_axis<red::sum, false>(
F<mxnet::op::mshadow_op::exp>(
log_probs_handle -
broadcast<0>(reduce_with_axis<red::maximum, false>(log_probs_handle, 1),
log_probs_handle.shape_)),
1);
denoms_handle = reduce_with_axis<red::sum, false>(F<mxnet::op::mshadow_op::exp>(log_probs_handle), 1);

// Kernel launch to calculate probabilities
compute_log_probs_kernel<ProbT, VT><<<grid_size, NT, 0, stream_>>>
Expand Down

0 comments on commit 9e251d3

Please sign in to comment.