[MXNET-40]add multi proposal operator (cpu version) and fix the bug in proposal op (gpu version) #9939
Conversation
Thanks for adding the CPU implementation. Would you like to add test cases for it, so the operator can be checked in Jenkins?
@TaoLv Thank you!
@@ -553,10 +553,10 @@ class ProposalGPUOp : public Operator{
        cudaMemcpyHostToDevice));

      // copy results after nms
-     dimGrid.x = (rpn_post_nms_top_n + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock;
+     dimGrid.x = (param_.rpn_post_nms_top_n + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock;
Should `param_.rpn_post_nms_top_n` be used here? Or `rpn_post_nms_top_n`, computed here originally: https://github.com/apache/incubator-mxnet/blob/master/src/operator/contrib/proposal.cu#L438 ? What's the bug in the original version? cc @precedenceguo
@ZiyueHuang If `count_anchor < param_.rpn_post_nms_top_n`, `rpn_post_nms_top_n` will be equal to `count_anchor` and thus less than `param_.rpn_post_nms_top_n`.
https://github.com/apache/incubator-mxnet/blob/master/src/operator/contrib/proposal-inl.h#L119
It shows that the `batch_size` of the output is `param_.rpn_post_nms_top_n` instead of `rpn_post_nms_top_n`.
The original version will not assign the elements whose index is larger than `rpn_post_nms_top_n`.
You are right. Elements whose index is larger than `rpn_post_nms_top_n` should be assigned. The values are copied from the valid anchors, just to satisfy the output size requirement.
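The padding behavior described above can be sketched as follows (a minimal Python sketch with hypothetical names and shapes, not the operator's actual code):

```python
import numpy as np

def pad_output(kept_boxes, param_post_nms_top_n):
    """Pad the NMS result up to the fixed output size.

    If fewer boxes survive NMS than the configured output size, the
    remaining slots are filled by repeating valid boxes, so the output
    shape is always (param_post_nms_top_n, 4).
    """
    n_kept = kept_boxes.shape[0]
    out = np.empty((param_post_nms_top_n, 4), dtype=kept_boxes.dtype)
    for i in range(param_post_nms_top_n):
        # copy from valid anchors (wrapping around) to satisfy the output size
        out[i] = kept_boxes[i % n_kept]
    return out

kept = np.array([[0., 0., 10., 10.], [5., 5., 20., 20.]])
padded = pad_output(kept, 5)
print(padded.shape)  # (5, 4)
```

Without this padding step, the trailing rows of the fixed-size output would hold uninitialized values, which is the bug being discussed.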
@precedenceguo Thank you!
I have a question.
The scores of the anchor boxes whose size is less than `rpn_min_size` will be assigned -1 (`FilterBox`), so these anchor boxes may survive NMS and be selected as output in the implementation. Are they valid anchors?
I think they are invalid and should be thrown away.
Reference:
https://github.com/precedenceguo/mx-rcnn/blob/master/rcnn/symbol/proposal.py#L116
Yes, they should be thrown out, but that would require a new array. Instead I set the predicted confidence to -1, the same as for the boxes whose centers are outside the input image boundary. Those invalid anchors are ranked at the bottom before NMS, so a proper threshold did not produce significant problems.
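A minimal sketch of this filtering strategy (hypothetical helper loosely following the `FilterBox` idea; not the operator's actual code):

```python
import numpy as np

def filter_box_scores(boxes, scores, min_size):
    """Set the score of too-small boxes to -1 instead of removing them.

    Keeping the arrays the same size avoids allocating a new array; the
    -1 scores push the invalid boxes to the bottom of the ranking.
    """
    ws = boxes[:, 2] - boxes[:, 0] + 1.0  # box widths
    hs = boxes[:, 3] - boxes[:, 1] + 1.0  # box heights
    invalid = (ws < min_size) | (hs < min_size)
    scores = scores.copy()
    scores[invalid] = -1.0
    return scores

boxes = np.array([[0, 0, 100, 100], [0, 0, 5, 5]], dtype=np.float32)
scores = np.array([0.9, 0.8], dtype=np.float32)
filtered = filter_box_scores(boxes, scores, min_size=16)
print(filtered)  # box 0 keeps its score, box 1 is marked invalid with -1
```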
Thank you!
The invalid anchors are ranked at the bottom for NMS, but the threshold applies to the overlap between two anchors rather than to the score.
Should we add a condition `score != -1` in NMS?
I think it is a good idea, but I did not find any significant issue caused by the invalid anchors.
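For illustration, a simplified greedy NMS with the suggested `score != -1` condition might look like this (a hypothetical sketch, not the operator's implementation):

```python
import numpy as np

def iou(a, b):
    # intersection-over-union of two [x1, y1, x2, y2] boxes
    x1, y1 = max(a[0], b[0]), max(a[1], b[1])
    x2, y2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, x2 - x1 + 1) * max(0.0, y2 - y1 + 1)
    area = lambda r: (r[2] - r[0] + 1) * (r[3] - r[1] + 1)
    return inter / (area(a) + area(b) - inter)

def nms(boxes, scores, thresh, skip_invalid=True):
    """Greedy NMS; optionally skip anchors whose score was set to -1."""
    order = np.argsort(-scores)
    keep, suppressed = [], np.zeros(len(boxes), dtype=bool)
    for i in order:
        if suppressed[i]:
            continue
        if skip_invalid and scores[i] == -1:  # the suggested extra condition
            continue
        keep.append(i)
        for j in order:
            if j != i and not suppressed[j] and iou(boxes[i], boxes[j]) > thresh:
                suppressed[j] = True
    return keep

boxes = np.array([[0, 0, 10, 10], [1, 1, 11, 11], [50, 50, 60, 60]], float)
scores = np.array([0.9, 0.8, -1.0])
kept_idx = nms(boxes, scores, 0.7)
print(kept_idx)  # [0] -- box 1 is suppressed by overlap, box 2 is invalid
```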
Could you add a cpu/gpu consistency test in tests/python/gpu/test_operator_gpu.py?
@piiswrong Yes, I will add it. The Proposal op (GPU implementation) uses stable sort, and the accuracy of computation differs between CPU and GPU, which causes NMS to select different boxes in the two implementations.
I wrote a cpu/gpu consistency test for Proposal and MultiProposal. It seems that the index order of the Non-Maximum-Suppression result may be different between the CPU implementation and the GPU implementation. Here is the cpu/gpu consistency test.

```python
import mxnet as mx
import numpy as np

# @with_seed()
def test_multi_proposal_op():
    # parameters
    feature_stride = 16
    scales = (8, 16, 32)
    ratios = (0.5, 1, 2)
    rpn_pre_nms_top_n = 12000
    rpn_post_nms_top_n = 2000
    threshold = 0.7
    rpn_min_size = 16

    feat_len = 14
    H, W = feat_len, feat_len
    num_anchors = len(scales) * len(ratios)
    count_anchors = H * W * num_anchors

    def get_new_data(batch_size, ctx):
        '''
        cls_prob: (batch_size, 2 * num_anchors, H, W)
        bbox_pred: (batch_size, 4 * num_anchors, H, W)
        im_info: (batch_size, 3)
        '''
        cls_prob = mx.nd.empty((batch_size, 2 * num_anchors, H, W), dtype = np.float32, ctx = ctx)
        bbox_pred = mx.nd.empty((batch_size, 4 * num_anchors, H, W), dtype = np.float32, ctx = ctx)
        im_info = mx.nd.empty((batch_size, 3), dtype = np.float32, ctx = ctx)

        cls_prob = mx.nd.array(np.random.random(cls_prob.shape), ctx = ctx)
        bbox_pred = mx.nd.array(np.random.random(bbox_pred.shape), ctx = ctx)

        for i in range(batch_size):
            im_size = np.random.randint(100, feat_len * feature_stride, size = (2,))
            im_scale = np.random.randint(70, 100) / 100.0
            im_info[i, :] = [im_size[0], im_size[1], im_scale]
        return cls_prob, bbox_pred, im_info

    def check_proposal_consistency(op, batch_size):
        '''
        op is mx.nd.contrib.Proposal or mx.nd.contrib.MultiProposal
        '''
        cls_prob, bbox_pred, im_info = get_new_data(batch_size, mx.cpu(0))
        rois_cpu, score_cpu = op(
            cls_score = cls_prob,
            bbox_pred = bbox_pred,
            im_info = im_info,
            feature_stride = feature_stride,
            scales = scales,
            ratios = ratios,
            rpn_pre_nms_top_n = rpn_pre_nms_top_n,
            rpn_post_nms_top_n = rpn_post_nms_top_n,
            threshold = threshold,
            rpn_min_size = rpn_min_size, output_score = True)

        gpu_ctx = mx.gpu(0)

        # copy data to gpu from cpu
        cls_prob_gpu = cls_prob.as_in_context(gpu_ctx)
        bbox_pred_gpu = bbox_pred.as_in_context(gpu_ctx)
        im_info_gpu = im_info.as_in_context(gpu_ctx)

        rois_gpu, score_gpu = op(
            cls_score = cls_prob_gpu,
            bbox_pred = bbox_pred_gpu,
            im_info = im_info_gpu,
            feature_stride = feature_stride,
            scales = scales,
            ratios = ratios,
            rpn_pre_nms_top_n = rpn_pre_nms_top_n,
            rpn_post_nms_top_n = rpn_post_nms_top_n,
            threshold = threshold,
            rpn_min_size = rpn_min_size, output_score = True)

        print(rois_cpu.asnumpy(), rois_gpu.asnumpy())
        assert np.allclose(rois_cpu.asnumpy(), rois_gpu.asnumpy())
        assert np.allclose(score_cpu.asnumpy(), score_gpu.asnumpy())

    check_proposal_consistency(mx.nd.contrib.Proposal, 1)
    check_proposal_consistency(mx.nd.contrib.MultiProposal, 20)

test_multi_proposal_op()
print("test ok")
```
@wkcn Great work :)
I see there are several nested for-loops, and I think you can parallelize them with OpenMP for better performance.
Would you mind taking a try?
If any help is needed, feel free to let me know.
for (int a = 0; a < anchors; ++a) {
  for (int h = 0; h < heights; ++h) {
    for (int w = 0; w < widths; ++w) {
      index_t index = h * (widths * anchors) + w * (anchors) + a;
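The three nested loops visit every `(a, h, w)` combination exactly once, so they can be merged into a single loop over one flat index (which is also, in effect, what OpenMP's `collapse` clause iterates over). A small Python sketch of the equivalence, using hypothetical dimensions:

```python
anchors, heights, widths = 3, 4, 5

# Triple loop, as written in the operator
triple = []
for a in range(anchors):
    for h in range(heights):
        for w in range(widths):
            triple.append(h * (widths * anchors) + w * anchors + a)

# Single collapsed loop over i = a * heights * widths + h * widths + w;
# each (a, h, w) triple is recovered from the flat index i.
flat = []
for i in range(anchors * heights * widths):
    a = i // (heights * widths)
    h = (i // widths) % heights
    w = i % widths
    flat.append(h * (widths * anchors) + w * anchors + a)

print(sorted(triple) == sorted(flat))  # True: both visit the same indices
```

The collapsed form exposes all `anchors * heights * widths` iterations to a single parallel-for, instead of only the `anchors` iterations of the outer loop.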
Is it possible to parallelize the for-loop with OMP?
Which one do you think should be parallelized? Always the outside one? The one with the largest number of iterations? The smallest number of iterations? What's the convention on this? (I've been curious about this myself.)
@cjolivier01 I think we can merge the three for-loops into a single loop, then use OMP to optimize it.
@wkcn You don't need to; `collapse` can work as you expect :)
https://software.intel.com/en-us/articles/openmp-loop-collapse-directive
Merging the three for-loops would make the code harder to understand.
@cjolivier01 In general, we prefer to parallelize the outer loop because more tasks can run simultaneously, like task-level parallelization. When we parallelize the inner loop (not this case), most likely we want to do vectorization with OMP (simd).
@pengzhao-intel Thank you :) I will read the document and try it.
for (int a = 0; a < anchors; ++a) {
  for (int h = 0; h < heights; ++h) {
    for (int w = 0; w < widths; ++w) {
Is it possible to parallelize the for-loop with OMP?
// calculate nms
*out_size = 0;
for (index_t i = 0; i < dets.size(0) && (*out_size) < post_nms_top_n; ++i) {
Is it possible to parallelize the for-loop with OMP?
@pengzhao-intel Thank you! I will have a try.
@wkcn What's the size of your for-loop in the test case?
@pengzhao-intel Here is the testing code. I found that I had made a mistake: I didn't use …
However, when I set the environment variables …
Update: 2018-03-09 (626296b)
@wkcn Nice work on the parallelization with OMP. A 2X speedup is a good starting point. I see you have added several OMP pragmas in the code. PS: you may need to add …
@wkcn @pengzhao-intel I think there are many units in multi_proposal.cc. We can help find which unit is the bottleneck and then fix it.
@pengzhao-intel @xinyu-intel
Using stable sort to sort the anchors (ReverseArgsort) increases the time by about 3000 ms.
Update: 2018-03-09 (626296b)
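The relevance of stable sort here is that it fixes the relative order of anchors with equal scores, which an unstable sort is free to permute. An illustrative numpy sketch (not the operator's actual `ReverseArgsort`):

```python
import numpy as np

scores = np.array([0.5, 0.9, 0.5, 0.9], dtype=np.float32)

# Descending argsort with a stable tiebreak: anchors with equal scores
# keep their original relative order, as std::stable_sort guarantees.
stable_desc = np.argsort(-scores, kind='stable')
print(stable_desc)  # [1 3 0 2]

# An unstable sort (e.g. quicksort, like std::sort) may order the tied
# indices differently, so downstream NMS can pick different boxes.
```

The tradeoff discussed in the thread: the deterministic tie order makes CPU and GPU results easier to compare, at the cost of the slower stable sort.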
For the cpu/gpu consistency test, it's interesting that the number of valid anchors in the CPU implementation and the GPU implementation may be different. The margin of the minimal valid anchor: … I want to create testing samples that avoid these margins.
Hi, you can just create an empty commit to trigger the CI yourself :)
JackieWu <notifications@github.com> wrote on Tue., Mar. 6, 2018, 07:49:
… @marcoabreu <https://github.com/marcoabreu>
Hello! Could you please retrigger the test?
It seems that test_operator_gpu.test_deconv has some problem.
And I have removed the 'USE_STABLE_SORT_FOR_PROPOSAL' building flag.
Thank you!
@marcoabreu Thank you :) I will commit it soon.
Hi, the community has voted to associate code changes with JIRA (https://lists.apache.org/thread.html/ab22cf0e35f1bce2c3bf3bec2bc5b85a9583a3fe7fd56ba1bbade55f@%3Cdev.mxnet.apache.org%3E). We have updated the guidelines for contributors at https://cwiki.apache.org/confluence/display/MXNET/Development+Process. Please ensure that you have created a JIRA issue at https://issues.apache.org/jira/projects/MXNET/issues/ describing your work in this pull request, and include the JIRA title in your PR as [MXNET-xxxx] your title, where MXNET-xxxx is the JIRA id. Thanks!
@wkcn Thanks for the detailed profiling data. Regarding sorting, you can use MKL sort in case the library is set to MKL; it will be much faster than a sequential sort. Regarding calculating nms, the current implementation is not ideal for performance. @xinyu-intel will run several experiments locally and get back to you later.
@pengzhao-intel Thank you!
@CodingCat I will read it. Thank you!
@TaoLv @ZiyueHuang @precedenceguo @piiswrong @pengzhao-intel @xinyu-intel @marcoabreu @CodingCat Hello! Here are the changes in this PR.
For Changes Requested, I have removed the USE_STABLE_SORT_FOR_PROPOSAL flag and restored the build settings. Could you please merge the PR? Thank you! :-)
@marcoabreu @szha @cjolivier01 Hello! For Changes Requested, I have removed the USE_STABLE_SORT_FOR_PROPOSAL flag and restored the build settings. Could you please review and merge the PR? Thank you!
LGTM
@marcoabreu Thank you! :)
//=====================
// NMS Utils
//=====================
namespace mxnet {
Did you check the sort_op in src/operator/tensor/sort_op.h?
@zhreshold Thank you!
It seems that I can use `SortByKey` in src/operator/tensor/sort_op.h; I may replace the sort in proposal.cc and multi_proposal.cc with it after that.
However, `SortByKey` in src/operator/tensor/sort_op.h uses `stable_sort`, which is slightly slower than `std::sort`.
@wkcn Hi, it looks like the multi proposal operator fails randomly. Do you have any idea?
@piiswrong The cpu/gpu consistency couldn't be guaranteed when the output contains invalid anchors whose score is -1. https://github.com/apache/incubator-mxnet/blob/master/src/operator/contrib/multi_proposal.cc#L89 I followed the Proposal operator and MultiProposal (GPU) to implement MultiProposal (CPU). There are three ways to solve the cpu/gpu consistency test failures in Proposal and MultiProposal: …
I will fix it soon.
Hello @wkcn, thanks a lot for the quick response! Do you have an ETA for when this is going to be resolved? As stated by Eric, we're experiencing a significant number of test failures, which is impacting the overall results of our CI.
@marcoabreu Sorry, I'm fixing it. When I finish, I will create a new PR. Could you disable the cpu/gpu consistency test temporarily? Another solution is to replace … Thank you!
Feel free to create a PR which disables that test, I'll merge it asap.
@marcoabreu I will fix the bug.
@piiswrong @marcoabreu I found the reason the CPU/GPU consistency test failed. When doing Non-Maximum Suppression and the overlap of two anchors is close to the threshold, the overlap may be greater than the threshold in the CPU implementation but less than the threshold in the GPU implementation. Here is the output (score) comparison between the MultiProposal CPU/GPU implementations.
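To illustrate this failure mode with hypothetical numbers (not the actual operator output): a rounding difference in the last bits of the computed overlap is enough to flip the suppression decision at the threshold.

```python
threshold = 0.7

# Hypothetical overlaps computed by two implementations that agree to
# about seven decimal places but round differently near the threshold:
overlap_cpu = 0.70000005
overlap_gpu = 0.69999999

print(overlap_cpu > threshold)  # True  -> the box is suppressed on CPU
print(overlap_gpu > threshold)  # False -> the same box is kept on GPU
```

Once one box is kept on one device and suppressed on the other, every subsequent NMS decision can diverge, which is why the consistency test fails only on samples with near-threshold overlaps.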
…n proposal op (gpu version) (apache#9939)

* add multi proposal operator (cpu version)
* lint fixed
* add the unit-test for MultiProposal Operator
* fix PrepareOutput bug in proposal op (gpu) and multi_proposal op (gpu)
* fix _nms function in the GPU implementation of Proposal and MultiProposal
* use openmp to optimize MultiProposal Operator
* fix MultiProposal building on Windows
* rm openmp collapse token in MultiProposal.cc
* use signed iterator for MultiProposal
* add cpu/gpu consistency test for Proposal and MultiProposal
* fix test samples in test_multi_proposal_op
* add USE_STABLE_SORT_PROPOSAL=1 for Jenkinsfile
* use stable sort for MultiProposal and change building settings back
* skip the cpu/gpu consistency test for Proposal Operator and MultiProposal Operator
* add unittest skip for test_multi_proposal_op
* retrigger test
* optimize multi proposal operator (cpu implementation)
* open the cpu/gpu consistency test for Proposal Op and MultiProposal Op
Description

The multi_proposal operator (`mxnet.sym.contrib.MultiProposal`, CPU version) was not implemented before. I wrote the code for it.
I also found a bug in proposal.cu and multi_proposal.cu.
The `batch_size` of the output of Proposal and MultiProposal is `param_.rpn_post_nms_top_n` in both cases.
When `count_anchors < param_.rpn_post_nms_top_n`, the variable `rpn_post_nms_top_n` will be `count_anchors`, which is less than `param_.rpn_post_nms_top_n`.
https://github.com/apache/incubator-mxnet/blob/master/src/operator/contrib/proposal.cu#L438
This causes the problem that output elements whose index is larger than `rpn_post_nms_top_n` are not assigned.

Checklist
Essentials

* Passed code style checking (`make lint`)

Changes

* Fixed the `PrepareOutput` bug in proposal op (GPU) and multi_proposal op (GPU)

Comments
Testing Code
Multi-Proposal Performance Table