Skip to content

Commit

Permalink
[MXNET-807] Support integer label type in ctc_loss operator (apache#1…
Browse files Browse the repository at this point in the history
…2468)

* Support integer type in ctc_loss

* Support any data type in ctc_loss operator

* Enable integer type in labels and fix lint errors

* Fix compilation error in GPU

* Add unit tests

* Undo indentation

* Undo blank line

* Undo blank line

* Add unit test for large number of classes

* move unit tests to test_operator.py per reviewer advice

* update unit test

* update unit test

* update unit test using random seed

* Update unit test

* Fix unit test difference Python2 and Python3
  • Loading branch information
apeforest authored and anirudh2290 committed Sep 19, 2018
1 parent 53dc7b9 commit 9f6eca7
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 68 deletions.
152 changes: 84 additions & 68 deletions src/operator/contrib/ctc_loss-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -256,66 +256,69 @@ class CTCLossOp : public Operator {
exceed_cudnn_limit = false;
Stream<xpu> *s = ctx.get_stream<xpu>();

Tensor<xpu, 3, real_t> data =
MSHADOW_TYPE_SWITCH(in_data[ctc_loss::kLabel].type_flag_, DType, {
Tensor<xpu, 3, real_t> data =
in_data[ctc_loss::kData].get<xpu, 3, real_t>(s);
Tensor<xpu, 2, real_t> labels =
in_data[ctc_loss::kLabel].get<xpu, 2, real_t>(s);
Tensor<xpu, 2, DType> labels =
in_data[ctc_loss::kLabel].get<xpu, 2, DType>(s);

Tensor<xpu, 1, real_t> costs =
Tensor<xpu, 1, real_t> costs =
out_data[ctc_loss::kOut].get<xpu, 1, real_t>(s);
Tensor<xpu, 3, real_t> grad =
Tensor<xpu, 3, real_t> grad =
out_data[ctc_loss::kGrad].get<xpu, 3, real_t>(s);

int max_seq_len = data.size(0);
int batch_size = data.size(1);
int alphabet_size = data.size(2);

// data_lengths
std::vector<int> data_lengths(batch_size, max_seq_len);
if (param_.use_data_lengths) {
int kInputLength = 2;
IndexTensorToVector(in_data[kInputLength].get<xpu, 1, real_t>(s), &data_lengths);
}

// label_lengths
std::vector<int> packed_labels;
std::vector<int> label_lengths(batch_size);

if (param_.use_label_lengths) {
int kLabelLength = 2+param_.use_data_lengths;
exceed_cudnn_limit = PackLabelByLength(labels, in_data[kLabelLength].get<xpu, 1, real_t>(s),
&packed_labels, &label_lengths);
} else {
exceed_cudnn_limit = LabelTensorToPackedVector(labels, param_.blank_label == 0?0:-1,
&packed_labels, &label_lengths);
}

// CUDNN is disabled due to lack of support for input lengths
/* #if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 */
/* if (!exceed_cudnn_limit) { */
/* cudnn_forward(ctx, s, data, costs, grad, */
/* &data_lengths, &label_lengths, &packed_labels, */
/* max_seq_len, batch_size, alphabet_size, */
/* req[ctc_loss::kGrad] != mxnet::kNullOp); */
/* } else { */
/* baidu_forward(ctx, s, data, costs, grad, */
/* &data_lengths, &label_lengths, &packed_labels, */
/* batch_size, alphabet_size, req[ctc_loss::kGrad] != mxnet::kNullOp); */
/* } */
/* #else */

baidu_forward(ctx, s, data, costs, grad,
&data_lengths, &label_lengths, &packed_labels,
batch_size, alphabet_size, req[ctc_loss::kGrad] != mxnet::kNullOp);

if (param_.use_data_lengths) {
// baidu warp CTC implementation sometimes includes undefined gradients
// for data outside of length mask. Setting to 0 to make it consistent
// with CPU implementation.
int kInputLength = 2;
mxnet_op::SequenceMask(grad, in_data[kInputLength].get<xpu, 1, real_t>(s),
static_cast<real_t>(0));
}
int max_seq_len = data.size(0);
int batch_size = data.size(1);
int alphabet_size = data.size(2);

// data_lengths
std::vector<int> data_lengths(batch_size, max_seq_len);
if (param_.use_data_lengths) {
int kInputLength = 2;
IndexTensorToVector(in_data[kInputLength].get<xpu, 1, real_t>(s), &data_lengths);
}

// label_lengths
std::vector<int> packed_labels;
std::vector<int> label_lengths(batch_size);

if (param_.use_label_lengths) {
int kLabelLength = 2 + param_.use_data_lengths;
exceed_cudnn_limit =
PackLabelByLength(labels, in_data[kLabelLength].get<xpu, 1, DType>(s),
&packed_labels, &label_lengths);
} else {
exceed_cudnn_limit = LabelTensorToPackedVector(labels, param_.blank_label == 0 ? 0 : -1,
&packed_labels, &label_lengths);
}

// CUDNN is disabled due to lack of support for input lengths
/* #if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 */
/* if (!exceed_cudnn_limit) { */
/* cudnn_forward(ctx, s, data, costs, grad, */
/* &data_lengths, &label_lengths, &packed_labels, */
/* max_seq_len, batch_size, alphabet_size, */
/* req[ctc_loss::kGrad] != mxnet::kNullOp); */
/* } else { */
/* baidu_forward(ctx, s, data, costs, grad, */
/* &data_lengths, &label_lengths, &packed_labels, */
/* batch_size, alphabet_size, req[ctc_loss::kGrad] != mxnet::kNullOp);*/
/* } */
/* #else */

baidu_forward(ctx, s, data, costs, grad,
&data_lengths, &label_lengths, &packed_labels,
batch_size, alphabet_size, req[ctc_loss::kGrad] != mxnet::kNullOp);

if (param_.use_data_lengths) {
// baidu warp CTC implementation sometimes includes undefined gradients
// for data outside of length mask. Setting to 0 to make it consistent
// with CPU implementation.
int kInputLength = 2;
mxnet_op::SequenceMask(grad, in_data[kInputLength].get<xpu, 1, real_t>(s),
static_cast<real_t>(0));
}
});
}

virtual void Backward(const OpContext &ctx,
Expand Down Expand Up @@ -434,17 +437,17 @@ class CTCLossOp : public Operator {
}
#endif // __CUDACC__ && CUDNN

inline virtual void baidu_forward(const OpContext &ctx,
mshadow::Stream<xpu>* s,
mshadow::Tensor<xpu, 3, real_t> data,
mshadow::Tensor<xpu, 1, real_t> costs,
mshadow::Tensor<xpu, 3, real_t> grad,
std::vector<int>* data_lengths,
std::vector<int>* label_lengths,
std::vector<int>* packed_labels,
int batch_size,
int alphabet_size,
bool req_grad) {
inline void baidu_forward(const OpContext &ctx,
mshadow::Stream<xpu>* s,
mshadow::Tensor<xpu, 3, real_t> data,
mshadow::Tensor<xpu, 1, real_t> costs,
mshadow::Tensor<xpu, 3, real_t> grad,
std::vector<int>* data_lengths,
std::vector<int>* label_lengths,
std::vector<int>* packed_labels,
int batch_size,
int alphabet_size,
bool req_grad) {
using namespace mshadow;
// allocate temporary workspace
size_t size_bytes;
Expand All @@ -461,7 +464,7 @@ class CTCLossOp : public Operator {
compute_ctc_cost(data, costs.dptr_, grad.dptr_, packed_labels->data(),
label_lengths->data(), data_lengths->data(),
workspace.dptr_, req_grad,
param_.blank_label == 0?0:(alphabet_size-1));
param_.blank_label == 0 ? 0 : (alphabet_size-1));
}
}; // class CTCLossOp

Expand Down Expand Up @@ -534,11 +537,24 @@ class CTCLossProp : public OperatorProperty {
TShape oshape(1);
oshape[0] = dshape[1]; // batch size
out_shape->clear();
out_shape->push_back(oshape);
out_shape->push_back(oshape); // forward output
out_shape->push_back(dshape); // grad output
return true;
}

bool InferType(std::vector<int> *in_type,
std::vector<int> *out_type,
std::vector<int> *aux_type) const override {
CHECK_LE(in_type->size(), this->ListArguments().size());
int dtype = (*in_type)[ctc_loss::kData];
CHECK_NE(dtype, -1) << "Input data must have specified type";

out_type->clear();
out_type->push_back(dtype); // forward output
out_type->push_back(dtype); // grad output
return true;
}

OperatorProperty *Copy() const override {
auto ptr = new CTCLossProp();
ptr->param_ = param_;
Expand Down
1 change: 1 addition & 0 deletions tests/python/unittest/test_contrib_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ def assert_match(inputs, x, y, threshold, is_ascend=False):
assert_match([[0.5, 0.6], [0.1, 0.2], [0.3, 0.4]], [1, -1, 0], [2, 0], 1e-12, False)
assert_match([[0.5, 0.6], [0.1, 0.2], [0.3, 0.4]], [-1, 0, 1], [1, 2], 100, True)


if __name__ == '__main__':
import nose
nose.runmodule()
24 changes: 24 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4516,6 +4516,30 @@ def test_ctc_loss():
true_loss = np.array([7.3557, 5.4091], dtype=np.float32) # from Torch
check_ctc_loss(acts2, labels2, true_loss)

# Test 3: check use integer type as label
labels3 = np.array([[2, 3, 1], [2, 0, 0]], dtype=np.int32)
true_loss = np.array([7.3557, 5.4091], dtype=np.float32) # from Torch
check_ctc_loss(acts2, labels3, true_loss)

@with_seed()
def test_ctc_loss_with_large_classes():
ctx = default_context()
num_classes = 6000
seq_len = 8
batch_size = 2
data = np.empty((num_classes, 0))
for i in range(seq_len * batch_size) :
row = np.roll(np.arange(num_classes, dtype=np.float32), i).reshape(num_classes, 1)
data = np.append(data, row/13, axis=1)
data = data.reshape(seq_len, batch_size, num_classes)
label = np.array([
[100, 200, 300, 400, 500, 0, 0, 0],
[1000, 2000, 3000, 4000, 0, 5000, 0, 0]], dtype=np.int32)
nd_data = mx.nd.array(data)
nd_label = mx.nd.array(label)
loss = mx.nd.contrib.ctc_loss(data=nd_data, label=nd_label)
expected_loss = np.array([688.02826, 145.34462])
assert_almost_equal(loss.asnumpy(), expected_loss)

@with_seed()
def test_ctc_loss_grad():
Expand Down

0 comments on commit 9f6eca7

Please sign in to comment.