diff --git a/src/operator/contrib/ctc_loss-inl.h b/src/operator/contrib/ctc_loss-inl.h index 9380be47451f..c8a8b2637401 100644 --- a/src/operator/contrib/ctc_loss-inl.h +++ b/src/operator/contrib/ctc_loss-inl.h @@ -256,66 +256,69 @@ class CTCLossOp : public Operator { exceed_cudnn_limit = false; Stream *s = ctx.get_stream(); - Tensor data = + MSHADOW_TYPE_SWITCH(in_data[ctc_loss::kLabel].type_flag_, DType, { + Tensor data = in_data[ctc_loss::kData].get(s); - Tensor labels = - in_data[ctc_loss::kLabel].get(s); + Tensor labels = + in_data[ctc_loss::kLabel].get(s); - Tensor costs = + Tensor costs = out_data[ctc_loss::kOut].get(s); - Tensor grad = + Tensor grad = out_data[ctc_loss::kGrad].get(s); - int max_seq_len = data.size(0); - int batch_size = data.size(1); - int alphabet_size = data.size(2); - - // data_lengths - std::vector data_lengths(batch_size, max_seq_len); - if (param_.use_data_lengths) { - int kInputLength = 2; - IndexTensorToVector(in_data[kInputLength].get(s), &data_lengths); - } - - // label_lengths - std::vector packed_labels; - std::vector label_lengths(batch_size); - - if (param_.use_label_lengths) { - int kLabelLength = 2+param_.use_data_lengths; - exceed_cudnn_limit = PackLabelByLength(labels, in_data[kLabelLength].get(s), - &packed_labels, &label_lengths); - } else { - exceed_cudnn_limit = LabelTensorToPackedVector(labels, param_.blank_label == 0?0:-1, - &packed_labels, &label_lengths); - } - -// CUDNN is disabled due to lack of support for input lengths -/* #if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 */ -/* if (!exceed_cudnn_limit) { */ -/* cudnn_forward(ctx, s, data, costs, grad, */ -/* &data_lengths, &label_lengths, &packed_labels, */ -/* max_seq_len, batch_size, alphabet_size, */ -/* req[ctc_loss::kGrad] != mxnet::kNullOp); */ -/* } else { */ -/* baidu_forward(ctx, s, data, costs, grad, */ -/* &data_lengths, &label_lengths, &packed_labels, */ -/* batch_size, alphabet_size, req[ctc_loss::kGrad] != mxnet::kNullOp); */ -/* } */ -/* #else */ - - baidu_forward(ctx, s, data, costs, grad, - &data_lengths, &label_lengths, &packed_labels, - batch_size, alphabet_size, req[ctc_loss::kGrad] != mxnet::kNullOp); - - if (param_.use_data_lengths) { - // baidu warp CTC implementation sometimes includes undefined gradients - // for data outside of length mask. Setting to 0 to make it consistent - // with CPU implementation. - int kInputLength = 2; - mxnet_op::SequenceMask(grad, in_data[kInputLength].get(s), - static_cast(0)); - } + int max_seq_len = data.size(0); + int batch_size = data.size(1); + int alphabet_size = data.size(2); + + // data_lengths + std::vector data_lengths(batch_size, max_seq_len); + if (param_.use_data_lengths) { + int kInputLength = 2; + IndexTensorToVector(in_data[kInputLength].get(s), &data_lengths); + } + + // label_lengths + std::vector packed_labels; + std::vector label_lengths(batch_size); + + if (param_.use_label_lengths) { + int kLabelLength = 2 + param_.use_data_lengths; + exceed_cudnn_limit = + PackLabelByLength(labels, in_data[kLabelLength].get(s), + &packed_labels, &label_lengths); + } else { + exceed_cudnn_limit = LabelTensorToPackedVector(labels, param_.blank_label == 0 ? 0 : -1, + &packed_labels, &label_lengths); + } + + // CUDNN is disabled due to lack of support for input lengths + /* #if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 */ + /* if (!exceed_cudnn_limit) { */ + /* cudnn_forward(ctx, s, data, costs, grad, */ + /* &data_lengths, &label_lengths, &packed_labels, */ + /* max_seq_len, batch_size, alphabet_size, */ + /* req[ctc_loss::kGrad] != mxnet::kNullOp); */ + /* } else { */ + /* baidu_forward(ctx, s, data, costs, grad, */ + /* &data_lengths, &label_lengths, &packed_labels, */ + /* batch_size, alphabet_size, req[ctc_loss::kGrad] != mxnet::kNullOp);*/ + /* } */ + /* #else */ + + baidu_forward(ctx, s, data, costs, grad, + &data_lengths, &label_lengths, &packed_labels, + batch_size, alphabet_size, req[ctc_loss::kGrad] != mxnet::kNullOp); + + if (param_.use_data_lengths) { + // baidu warp CTC implementation sometimes includes undefined gradients + // for data outside of length mask. Setting to 0 to make it consistent + // with CPU implementation. + int kInputLength = 2; + mxnet_op::SequenceMask(grad, in_data[kInputLength].get(s), + static_cast(0)); + } + }); } virtual void Backward(const OpContext &ctx, @@ -434,17 +437,17 @@ class CTCLossOp : public Operator { } #endif // __CUDACC__ && CUDNN - inline virtual void baidu_forward(const OpContext &ctx, - mshadow::Stream* s, - mshadow::Tensor data, - mshadow::Tensor costs, - mshadow::Tensor grad, - std::vector* data_lengths, - std::vector* label_lengths, - std::vector* packed_labels, - int batch_size, - int alphabet_size, - bool req_grad) { + inline void baidu_forward(const OpContext &ctx, + mshadow::Stream* s, + mshadow::Tensor data, + mshadow::Tensor costs, + mshadow::Tensor grad, + std::vector* data_lengths, + std::vector* label_lengths, + std::vector* packed_labels, + int batch_size, + int alphabet_size, + bool req_grad) { using namespace mshadow; // allocate temporary workspace size_t size_bytes; @@ -461,7 +464,7 @@ class CTCLossOp : public Operator { compute_ctc_cost(data, costs.dptr_, grad.dptr_, packed_labels->data(), label_lengths->data(), data_lengths->data(), workspace.dptr_, req_grad, - param_.blank_label == 0?0:(alphabet_size-1)); + param_.blank_label == 0 ? 0 : (alphabet_size-1)); } }; // class CTCLossOp @@ -534,11 +537,24 @@ class CTCLossProp : public OperatorProperty { TShape oshape(1); oshape[0] = dshape[1]; // batch size out_shape->clear(); - out_shape->push_back(oshape); + out_shape->push_back(oshape); // forward output out_shape->push_back(dshape); // grad output return true; } + bool InferType(std::vector *in_type, + std::vector *out_type, + std::vector *aux_type) const override { + CHECK_LE(in_type->size(), this->ListArguments().size()); + int dtype = (*in_type)[ctc_loss::kData]; + CHECK_NE(dtype, -1) << "Input data must have specified type"; + + out_type->clear(); + out_type->push_back(dtype); // forward output + out_type->push_back(dtype); // grad output + return true; + } + OperatorProperty *Copy() const override { auto ptr = new CTCLossProp(); ptr->param_ = param_; diff --git a/tests/python/unittest/test_contrib_operator.py b/tests/python/unittest/test_contrib_operator.py index fc6c1be9c3a1..76efe305bceb 100644 --- a/tests/python/unittest/test_contrib_operator.py +++ b/tests/python/unittest/test_contrib_operator.py @@ -244,6 +244,7 @@ def assert_match(inputs, x, y, threshold, is_ascend=False): assert_match([[0.5, 0.6], [0.1, 0.2], [0.3, 0.4]], [1, -1, 0], [2, 0], 1e-12, False) assert_match([[0.5, 0.6], [0.1, 0.2], [0.3, 0.4]], [-1, 0, 1], [1, 2], 100, True) + if __name__ == '__main__': import nose nose.runmodule() diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 9842a69e18d4..4ec4bf1b384f 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -4516,6 +4516,30 @@ def test_ctc_loss(): true_loss = np.array([7.3557, 5.4091], dtype=np.float32) # from Torch check_ctc_loss(acts2, labels2, true_loss) + # Test 3: check use integer type as label + labels3 = np.array([[2, 3, 1], [2, 0, 0]], dtype=np.int32) + true_loss = np.array([7.3557, 5.4091], dtype=np.float32) # from Torch + check_ctc_loss(acts2, labels3, true_loss) + +@with_seed() +def test_ctc_loss_with_large_classes(): + ctx = default_context() + num_classes = 6000 + seq_len = 8 + batch_size = 2 + data = np.empty((num_classes, 0)) + for i in range(seq_len * batch_size) : + row = np.roll(np.arange(num_classes, dtype=np.float32), i).reshape(num_classes, 1) + data = np.append(data, row/13, axis=1) + data = data.reshape(seq_len, batch_size, num_classes) + label = np.array([ + [100, 200, 300, 400, 500, 0, 0, 0], + [1000, 2000, 3000, 4000, 0, 5000, 0, 0]], dtype=np.int32) + nd_data = mx.nd.array(data) + nd_label = mx.nd.array(label) + loss = mx.nd.contrib.ctc_loss(data=nd_data, label=nd_label) + expected_loss = np.array([688.02826, 145.34462]) + assert_almost_equal(loss.asnumpy(), expected_loss) @with_seed() def test_ctc_loss_grad():