[MXNET-807] Support integer label type in ctc_loss operator #12468
Conversation
src/operator/contrib/ctc_loss-inl.h
Outdated
PackLabelByLength(labels, in_data[kLabelLength].get<xpu, 1, DType>(s),
                  &packed_labels, &label_lengths);
} else {
  exceed_cudnn_limit = LabelTensorToPackedVector(labels, param_.blank_label == 0?0:-1,
Some formatting issues with whitespaces and indentation
Sorry, what exactly was the issue? make lint seems to pass.
Sorry, should be 0 ? 0 : -1.
@apeforest Could you rebase your change on latest master?
@apeforest If the ctc_loss operator is complete, we should consider moving this to operator/nn and out of contrib.
@lebeg @samskalicky I have merged from master and the check now passes. Please review the PR again. Thanks
Hi @szha, @samskalicky asks if we could move this operator from contrib to mxnet regular. Do you have any suggestion? Thanks!
@Jerryzcn @zhiheng-huang you're likely interested in this.
Please get rid of all indentation changes.
@szha The extra indentation in the block was due to the addition of the macro MSHADOW_TYPE_SWITCH on line 258. Other places were due to make lint failures.
src/operator/contrib/ctc_loss-inl.h
Outdated
enum CTCLossOpForwardResource { kTempSpace };
enum CTCLossOpInputs { kData, kLabel };
enum CTCLossOpOutputs { kOut, kGrad };
enum CTCLossOpForwardResource { kTempSpace };
Which case does this part fall into? How were we able to check it in without breaking the master build if it's either of the cases?
Sorry, overlooked these lines. I have removed them.
src/operator/contrib/ctc_loss-inl.h
Outdated
Tensor<xpu, 3, real_t> data_grad_computed =
    out_data[ctc_loss::kGrad].get<xpu, 3, real_t>(s);
    out_data[ctc_loss::kGrad].get<xpu, 3, real_t>(s);
same here.
Removed. However, this 4-space indentation seems to violate the Google C++ style guide. Are we ignoring it in the lint?
https://google.github.io/styleguide/cppguide.html
Actually, there already are tests.
@@ -244,6 +244,64 @@ def assert_match(inputs, x, y, threshold, is_ascend=False):
    assert_match([[0.5, 0.6], [0.1, 0.2], [0.3, 0.4]], [1, -1, 0], [2, 0], 1e-12, False)
    assert_match([[0.5, 0.6], [0.1, 0.2], [0.3, 0.4]], [-1, 0, 1], [1, 2], 100, True)

def test_ctc_loss_op():
CTC loss tests can be found at https://github.com/apache/incubator-mxnet/blob/master/tests/python/unittest/test_operator.py#L4500, and integration tests at https://github.com/apache/incubator-mxnet/blob/master/tests/python/unittest/test_loss.py#L186. Test cases are from hand-calculated examples.
Feel free to add test cases for large labels there.
Thanks for the reference. Moving the unit test there.
loss = mx.nd.contrib.ctc_loss(data=data, label=label)
loss = mx.nd.make_loss(loss)
expected_output = [9.604521, 7.096151, 4.906869, 5.5237527, 5.9895644, 5.584548,
                   5.528411, 5.765914, 6.740701, 5.2625823]
This testing strategy (i.e., comparing the output from random inputs and labels under a fixed seed against recorded output) is not meaningful and does not guarantee anything. It merely increases the line coverage.
Did not notice the unit test in test_operator.py. I have removed this one.
Also, since this is still using the legacy op interface, would you mind adopting the new operator interface for this?
@szha This PR is to fix the unsupported integer label type. If we were to refactor this operator altogether, I would prefer to do it in another PR. What do you think? Also, as @samskalicky suggested, do you think it is mature enough to move this operator from contrib to regular? If so, we can create another ticket to perform this migration together with the refactoring.
label_len = 10
num_classes = 6000
x = np.random.uniform(size=(seq_len, batch_size, num_classes))
y = np.random.randint(0, num_classes, size=(batch_size, label_len))
Again, this does not seem like a good way of testing this.
Any suggestion on how to test the large number of classes? I could compare this with the WarpCTC implementation's result if that can be treated as golden.
Make a small example, calculate the value, and test for that, like in any other CTC tests. Since this is for testing the type, the batch size and sequence lengths are irrelevant.
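To illustrate the kind of hand-calculated check the reviewer describes, here is a sketch (plain NumPy/Python, not MXNet code) that computes CTC loss for a tiny example with the standard forward (alpha) recursion and cross-checks it against a brute-force sum over all alignments. Assumptions: blank occupies index 0 (the blank_label='first' convention), and the helper names are my own.

```python
import itertools
import math

import numpy as np

BLANK = 0  # assumption: blank occupies class index 0

def ctc_loss(probs, label):
    """CTC negative log-likelihood via the standard forward (alpha) recursion.

    probs: (T, C) array of per-step class probabilities; label: class indices.
    """
    T, _ = probs.shape
    ext = [BLANK]
    for c in label:
        ext += [c, BLANK]              # interleave blanks: b, l1, b, l2, b, ...
    S = len(ext)
    alpha = np.zeros((T, S))
    alpha[0, 0] = probs[0, BLANK]
    alpha[0, 1] = probs[0, ext[1]]
    for t in range(1, T):
        for s in range(S):
            a = alpha[t - 1, s]
            if s > 0:
                a += alpha[t - 1, s - 1]
            # the skip transition is allowed unless the symbol is blank or repeats
            if s > 1 and ext[s] != BLANK and ext[s] != ext[s - 2]:
                a += alpha[t - 1, s - 2]
            alpha[t, s] = a * probs[t, ext[s]]
    return -math.log(alpha[T - 1, S - 1] + alpha[T - 1, S - 2])

def brute_force_loss(probs, label):
    """Sum path probabilities over every length-T alignment collapsing to label."""
    T, C = probs.shape
    total = 0.0
    for path in itertools.product(range(C), repeat=T):
        collapsed = [k for k, _ in itertools.groupby(path)]  # drop repeats first
        collapsed = [k for k in collapsed if k != BLANK]     # then drop blanks
        if collapsed == list(label):
            total += float(np.prod([probs[t, k] for t, k in enumerate(path)]))
    return -math.log(total)

probs = np.array([[0.6, 0.3, 0.1],
                  [0.2, 0.5, 0.3],
                  [0.1, 0.2, 0.7]])
# Both methods should agree on this hand-checkable example.
assert abs(ctc_loss(probs, [1, 2]) - brute_force_loss(probs, [1, 2])) < 1e-9
print(ctc_loss(probs, [1, 2]))  # -log(0.429)
```

The brute force is only feasible for tiny T and C, but it gives an independently verifiable expected value for a unit test, which is exactly what a fixed-seed comparison against recorded output does not.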
The label type is tested in line 4520. This test case is to test the large number of classes that caused the crash reported in issue #10995.
@with_seed(1)
def test_ctc_loss_with_large_classes():
    ctx = default_context()
    batch_size = 1024
How does this help to verify the correctness?
I simply used the example reported in the original issue to make sure this fix addressed that.
The issue that needs testing is the type of the labels, so a large batch size doesn't seem helpful or necessary for verifying the correctness.
Also, tests with fixed seed are treated as a test quality issue and are being eliminated right now.
The label type is tested in line 4520. This test case is to test the large number of classes that caused the crash reported in issue #10995.
OK, then make a test for it. Batch size is still not relevant, is it?
It is not the batch_size in training. It is the size of the vocabulary. We need this variable to create the 3D tensor.
Updated the variable name and removed the fixed seed.
No, it really is not the vocabulary size, regardless of how you name it. Please check the API doc and see its usage.
Sorry for my misunderstanding the API. I have updated the unit tests based on your suggestion. Please review it again. Thanks!
data = mx.nd.array(x, ctx=ctx)
label = mx.nd.array(y, ctx=ctx)
loss = mx.nd.contrib.ctc_loss(data=data, label=label)
assert loss.asnumpy().shape[0] == m
Why are you testing for shape?
Just to test that the operator does not crash with a large number of classes.
This test does not crash on the master branch without the change either.
That's true. This unit test is not to test my fix. It is to test an earlier PR, #11834, which did not include a unit test but was merged somehow.
Thanks for that. Still, the batch size is unnecessarily large. Why not make the test run faster? Also, there's still no test that covers the loss of precision problem that the integer label type solves, which is part of your fix. Would you mind adding that please?
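As background on the precision point raised here (my illustration, not part of the PR): float32 carries a 24-bit significand, so label indices stored as float32 stop being exactly representable beyond 2**24, which is one concrete reason an integer label type is safer.

```python
import numpy as np

# float32 has a 24-bit significand, so not every integer above 2**24 is
# representable; a large label id stored as float32 can silently change.
print(int(np.float32(2**24)))      # 16777216 is representable exactly
print(int(np.float32(2**24 + 1)))  # 16777217 rounds to 16777216
```

Vocabulary sizes in the thousands are well below this limit, but an integer label type removes the concern entirely rather than relying on the values staying small.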
Updated the batch size to 2.
Created a ticket https://issues.apache.org/jira/browse/MXNET-912 to move this operator from contrib to regular.
The ctc_loss operator yields different results in Python 3 and Python 2 environments, breaking the newly added unit test. I am investigating the root cause.
…2468)

* Support integer type in ctc_loss
* Support any data type in ctc_loss operator
* Enable integer type in labels and fix lint errors
* Fix compilation error in GPU
* Add unit tests
* Undo indentation
* Undo blank line
* Undo blank line
* Add unit test for large number of classes
* move unit tests to test_operator.py per reviewer advice
* update unit test
* update unit test
* update unit test using random seed
* Update unit test
* Fix unit test difference Python2 and Python3
Description
This PR fixes part of the issues in #10995. It supports integer types in labels, just as WarpCTC does.
Also added a unit test for the large-class issue fixed by another PR, #11834.