
[MXNET-40]add multi proposal operator (cpu version) and fix the bug in proposal op (gpu version) #9939

Merged (19 commits) on Mar 20, 2018

Conversation

@wkcn (Member) commented Mar 1, 2018

Description

The multi_proposal operator (mxnet.sym.contrib.MultiProposal) did not have a CPU implementation before. This PR adds one.

I also found a bug in proposal.cu and multi_proposal.cu.
The batch_size of the output of Proposal and MultiProposal is param_.rpn_post_nms_top_n.
When count_anchors < param_.rpn_post_nms_top_n, the variable rpn_post_nms_top_n will be set to count_anchors, which is less than param_.rpn_post_nms_top_n.
https://github.com/apache/incubator-mxnet/blob/master/src/operator/contrib/proposal.cu#L438

As a result, output elements whose index is larger than rpn_post_nms_top_n are never assigned.
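
For illustration, here is a rough sketch (not the actual CUDA kernel) of what the output-preparation step has to guarantee: every one of the param_.rpn_post_nms_top_n output rows is written, padding from already-kept proposals when NMS kept fewer boxes. The helper name and array layout below are hypothetical.

#include <vector>

// Sketch only: fill all post_nms_top_n output rows, reusing kept proposals
// as padding so no row is left unassigned.
void PrepareOutputSketch(const std::vector<int>& keep,   // indices kept by NMS
                         const float* proposals,         // [num_proposals][5]
                         int post_nms_top_n,             // param_.rpn_post_nms_top_n
                         float* out) {                   // [post_nms_top_n][5]
  if (keep.empty()) return;  // degenerate case, ignored in this sketch
  const int num_keep = static_cast<int>(keep.size());
  for (int i = 0; i < post_nms_top_n; ++i) {
    const int src = keep[i % num_keep];  // beyond num_keep, cycle over kept boxes
    for (int j = 0; j < 5; ++j) {
      out[i * 5 + j] = proposals[src * 5 + j];
    }
  }
}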

Checklist

Essentials

  • Passed code style checking (make lint)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, expected performance on the test set, and a reference to the original paper if applicable
  • To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Support the Multi Proposal operator on the CPU context
  • Add a unit test for the Multi Proposal operator
  • Add a cpu/gpu consistency test for Proposal and MultiProposal
  • Fix the PrepareOutput bug in the proposal op (GPU) and multi_proposal op (GPU)
  • Fix the output difference between Proposal (CPU) and Proposal (GPU)
  • Use OpenMP to optimize the loops in Proposal (CPU) and MultiProposal (CPU)

Comments

Testing Code

Multi-Proposal Performance Table

Performance   CPU (no OpenMP)   CPU (OpenMP)   GPU
Time (s)      33.899            5.049          4.435

@wkcn requested a review from cjolivier01 as a code owner, March 1, 2018 06:10
@TaoLv (Member) commented Mar 1, 2018

Thanks for adding the CPU implementation. Would you like to add test cases for it, so the operator can be checked in Jenkins?

@wkcn (Member Author) commented Mar 1, 2018

@TaoLv Thank you!
It seems that there is no unit test for the Proposal operator, but the test for Multi-Proposal relies on that of Proposal.
Which file should I add a unit test for the Multi-Proposal operator to?
Is it enough to use the testing code I posted in the comments?

@wkcn changed the title from "add multi proposal operator (cpu version)" to "add multi proposal operator (cpu version) and fix the bug in proposal op (gpu version)" on Mar 1, 2018
@@ -553,10 +553,10 @@ class ProposalGPUOp : public Operator{
cudaMemcpyHostToDevice));

// copy results after nms
-      dimGrid.x = (rpn_post_nms_top_n + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock;
+      dimGrid.x = (param_.rpn_post_nms_top_n + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock;
Member

Should param_.rpn_post_nms_top_n be used here? Or rpn_post_nms_top_n computed here originally https://github.com/apache/incubator-mxnet/blob/master/src/operator/contrib/proposal.cu#L438 ? What's the bug in the original version? cc @precedenceguo

@wkcn (Member Author) Mar 1, 2018

@ZiyueHuang If count_anchor < param_.rpn_post_nms_top_n,
rpn_post_nms_top_n will be equal to count_anchor,
and therefore less than param_.rpn_post_nms_top_n.

https://github.com/apache/incubator-mxnet/blob/master/src/operator/contrib/proposal-inl.h#L119
This shows that the batch_size of the output is param_.rpn_post_nms_top_n instead of rpn_post_nms_top_n.

The original version does not assign the elements whose index is larger than rpn_post_nms_top_n.

Contributor

You are right. Elements whose index is larger than rpn_post_nms_top_n should be assigned. The values are copied from the valid anchors, just to satisfy the output size requirement.

@wkcn (Member Author) Mar 3, 2018

@precedenceguo Thank you!
I have a question.
The scores of anchor boxes whose size is less than rpn_min_size are assigned -1 (FilterBox), yet these anchor boxes may still pass NMS and be selected as output in the current implementation. Are they valid anchors?

I think they are invalid and should be thrown away.

Reference:
https://github.com/precedenceguo/mx-rcnn/blob/master/rcnn/symbol/proposal.py#L116

@ijkguo (Contributor) Mar 4, 2018

Yes, they should be thrown out; but that would require a new array. Instead I set the predicted confidence to be -1, same as the boxes whose centers are out of the input image boundary. Those invalid anchors would be ranked to the bottom for NMS so a proper threshold did not produce significant problems.

Member Author

Thank you!
The invalid anchors would be ranked to the bottom for NMS, but the NMS threshold applies to the overlap between two anchors rather than to the score.
Should we add a condition that score != -1 in NMS?

https://github.com/apache/incubator-mxnet/blob/master/src/operator/contrib/proposal.cc#L259
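
For illustration only, a minimal sketch of what such a guard could look like in a greedy NMS selection loop. This is not the proposal.cc code; Overlap and all other names here are hypothetical stand-ins.

#include <vector>

// Sketch: greedy NMS selection with an extra "score != -1" guard, so anchors
// invalidated by FilterBox are never emitted.
std::vector<int> SelectAfterNms(const std::vector<float>& scores,   // sorted descending
                                float nms_thresh, int post_nms_top_n,
                                float (*Overlap)(int, int)) {        // IoU of boxes i and j
  const int n = static_cast<int>(scores.size());
  std::vector<char> suppressed(n, 0);
  std::vector<int> keep;
  for (int i = 0; i < n && static_cast<int>(keep.size()) < post_nms_top_n; ++i) {
    if (suppressed[i]) continue;
    if (scores[i] == -1.0f) continue;   // proposed guard: skip invalidated anchors
    keep.push_back(i);
    for (int j = i + 1; j < n; ++j) {
      if (!suppressed[j] && Overlap(i, j) > nms_thresh) suppressed[j] = 1;
    }
  }
  return keep;
}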

Contributor

I think it is a good idea, but I have not seen the invalid anchors cause a significant issue.

@piiswrong (Contributor)

Could you add a cpu/gpu consistency test in tests/python/gpu/test_operator_gpu.py?

@wkcn (Member Author) commented Mar 1, 2018

@piiswrong Yes, I will add it.
I found that the Proposal op (CPU implementation) uses an unstable sort (std::sort):
https://github.com/apache/incubator-mxnet/blob/master/src/operator/contrib/proposal.cc#L195

while the Proposal op (GPU implementation) uses a stable sort (thrust::stable_sort_by_key):
https://github.com/apache/incubator-mxnet/blob/master/src/operator/contrib/proposal.cu#L517

In addition, the floating-point precision differs between CPU and GPU, which causes NMS to select different boxes on CPU and GPU.

When I replace std::sort with std::stable_sort in the CPU implementation, the cpu/gpu consistency test passes.
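
For reference, a minimal sketch of an index-based "reverse argsort" using std::stable_sort, which keeps anchors with equal scores in their original order and therefore behaves more like thrust::stable_sort_by_key. The function name is hypothetical and this is not the actual proposal.cc code.

#include <algorithm>
#include <numeric>
#include <vector>

// Sketch: return anchor indices ordered by descending score; a stable sort
// preserves the relative order of equal scores.
std::vector<int> ReverseArgsortStable(const std::vector<float>& scores) {
  std::vector<int> order(scores.size());
  std::iota(order.begin(), order.end(), 0);   // 0, 1, 2, ...
  std::stable_sort(order.begin(), order.end(),
                   [&scores](int a, int b) { return scores[a] > scores[b]; });
  return order;
}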

@wkcn (Member Author) commented Mar 2, 2018

I wrote a cpu/gpu consistency test for Proposal and MultiProposal.
I found that there is a difference between the CPU output and the GPU output of mx.nd.contrib.Proposal.

It seems that the index order of the Non-Maximum-Suppression result may differ between the CPU implementation and the GPU implementation.
Another problem is that the GPU NMS loop may need the additional condition num_to_keep < rpn_post_nms_top_n:
https://github.com/apache/incubator-mxnet/blob/master/src/operator/contrib/proposal.cu#L341
reference: https://github.com/apache/incubator-mxnet/blob/master/src/operator/contrib/proposal.cc#L235

Here is the cpu/gpu consistency test.

import mxnet as mx
import numpy as np

# @with_seed()
def test_multi_proposal_op():
    # parameters
    feature_stride = 16
    scales = (8, 16, 32)
    ratios = (0.5, 1, 2)
    rpn_pre_nms_top_n = 12000
    rpn_post_nms_top_n = 2000
    threshold = 0.7
    rpn_min_size = 16

    feat_len = 14
    H, W = feat_len, feat_len
    num_anchors = len(scales) * len(ratios)
    count_anchors = H * W * num_anchors

    def get_new_data(batch_size, ctx):
        '''
        cls_prob: (batch_size, 2 * num_anchors, H, W)
        bbox_pred: (batch_size, 4 * num_anchors, H, W)
        im_info: (batch_size, 3)
        '''

        cls_prob = mx.nd.empty((batch_size, 2 * num_anchors, H, W), dtype = np.float32, ctx = ctx)
        bbox_pred = mx.nd.empty((batch_size, 4 * num_anchors, H, W), dtype = np.float32, ctx = ctx)
        im_info = mx.nd.empty((batch_size, 3), dtype = np.float32, ctx = ctx)

        cls_prob = mx.nd.array(np.random.random(cls_prob.shape), ctx = ctx)
        bbox_pred = mx.nd.array(np.random.random(bbox_pred.shape), ctx = ctx)

        for i in range(batch_size):
            im_size = np.random.randint(100, feat_len * feature_stride, size = (2,))
            im_scale = np.random.randint(70, 100) / 100.0
            im_info[i, :] = [im_size[0], im_size[1], im_scale]
        return cls_prob, bbox_pred, im_info

    def check_proposal_consistency(op, batch_size):
        '''
            op is mx.nd.contrib.Proposal or mx.nd.contrib.MultiProposal
        '''
        cls_prob, bbox_pred, im_info = get_new_data(batch_size, mx.cpu(0))
        rois_cpu, score_cpu = op(
                cls_score = cls_prob,
                bbox_pred = bbox_pred,
                im_info = im_info,
                feature_stride = feature_stride,
                scales = scales,
                ratios = ratios,
                rpn_pre_nms_top_n = rpn_pre_nms_top_n,
                rpn_post_nms_top_n = rpn_post_nms_top_n,
                threshold = threshold,
                rpn_min_size = rpn_min_size, output_score = True)

        gpu_ctx = mx.gpu(0)

        # copy data to gpu from cpu
        cls_prob_gpu = cls_prob.as_in_context(gpu_ctx)
        bbox_pred_gpu = bbox_pred.as_in_context(gpu_ctx)
        im_info_gpu = im_info.as_in_context(gpu_ctx)

        rois_gpu, score_gpu = op(
                cls_score = cls_prob_gpu,
                bbox_pred = bbox_pred_gpu,
                im_info = im_info_gpu,
                feature_stride = feature_stride,
                scales = scales,
                ratios = ratios,
                rpn_pre_nms_top_n = rpn_pre_nms_top_n,
                rpn_post_nms_top_n = rpn_post_nms_top_n,
                threshold = threshold,
                rpn_min_size = rpn_min_size, output_score = True)

        print (rois_cpu.asnumpy(), rois_gpu.asnumpy())
        assert np.allclose(rois_cpu.asnumpy(), rois_gpu.asnumpy())
        assert np.allclose(score_cpu.asnumpy(), score_gpu.asnumpy())

    check_proposal_consistency(mx.nd.contrib.Proposal, 1)
    check_proposal_consistency(mx.nd.contrib.MultiProposal, 20)

test_multi_proposal_op()
print ("test ok")

@pengzhao-intel (Contributor) left a comment

@wkcn Great work :)

I see there are several nested for-loops, and I think you can parallelize them with OpenMP for better performance.

Would you mind giving it a try?

If any help is needed, feel free to let me know.

for (int a = 0; a < anchors; ++a) {
  for (int h = 0; h < heights; ++h) {
    for (int w = 0; w < widths; ++w) {
      index_t index = h * (widths * anchors) + w * (anchors) + a;
Contributor

Is it possible to parallelize this for-loop with OMP?

Member

which one do you think should be parallelized? always the outside one? the one with the largest number of iterations? the smallest number of iterations? what’s the convention on this (i’ve been curious of this myself)

Member Author

@cjolivier01 I think we can merge the three for-loops into a single for-loop, then use OMP to optimize it.

Contributor

@wkcn You don't need to. collapse can do what you expect :)
https://software.intel.com/en-us/articles/openmp-loop-collapse-directive

Merging the three for-loops would make the code hard to understand.

Contributor

@cjolivier01 In general, we prefer to parallelize the outer loop because more tasks can run simultaneously, similar to task-level parallelization. When we parallelize the inner loop (not this case), most likely we want vectorization via OMP (simd).
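
As a minimal sketch of what the collapsed version could look like, assuming the same variables as in the snippet above (whether collapse(3) actually pays off here is something to verify by measurement):

// Sketch: treat the three perfectly nested loops as one iteration space.
#pragma omp parallel for collapse(3)
for (int a = 0; a < anchors; ++a) {
  for (int h = 0; h < heights; ++h) {
    for (int w = 0; w < widths; ++w) {
      // Each (a, h, w) triple is independent; the body writes only to its own
      // index, so no synchronization is needed.
      index_t index = h * (widths * anchors) + w * (anchors) + a;
      // ... fill workspace entry `index` ...
    }
  }
}

Note that collapse requires OpenMP 3.0 or later, which MSVC's OpenMP 2.0 implementation does not provide; that is presumably why the collapse clause later had to be removed for the Windows build.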

Member Author

@pengzhao-intel Thank you:) I will read the document and try it.


for (int a = 0; a < anchors; ++a) {
  for (int h = 0; h < heights; ++h) {
    for (int w = 0; w < widths; ++w) {
Contributor

Is it possible to parallelize this for-loop with OMP?


// calculate nms
*out_size = 0;
for (index_t i = 0; i < dets.size(0) && (*out_size) < post_nms_top_n; ++i) {
Contributor

Is it possible to parallelize this for-loop with OMP?

@wkcn (Member Author) commented Mar 2, 2018

@pengzhao-intel Thank you! I will give it a try.

I used #pragma omp parallel for on each for-loop in Multi Proposal (CPU implementation),
but the performance doesn't improve.
Maybe the amount of computation is too small.

@pengzhao-intel (Contributor)

@wkcn What's the size of your for-loop in the test case?
How can I test your code?

@wkcn (Member Author) commented Mar 3, 2018

@pengzhao-intel Here is the testing code.
https://gist.github.com/wkcn/4a09c142bc9886b45b5a23461bbe4733

I found that I had made a mistake: I didn't call nd.waitall() when timing the performance.
Without nd.waitall(), the computation does not actually execute because of lazy evaluation.

Performance   CPU (no OMP)   CPU (OMP)   GPU
Time (s)      33.899         12.432      4.435

However, setting the environment variables MXNET_OMP_MAX_THREADS or OMP_NUM_THREADS may hurt performance.

Update: 2018-03-09 (626296b)

Performance   CPU (no OMP)   CPU (OMP)   GPU
Time (s)      33.899         5.049       4.435

@wkcn requested a review from szha as a code owner, March 3, 2018 14:18
@pengzhao-intel (Contributor)

@wkcn Nice work on the parallelization with OMP. A 2X speedup is a good starting point.

I see you have added several OMP pragmas in the code.
Could you profile where the performance bottleneck is now?

PS: you may need to add gettimeofday calls in your C++ code to identify the most time-consuming loop (or section).
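
As an illustration of this suggestion, a small timing helper using std::chrono instead of gettimeofday; the section name in the usage comment is made up.

#include <chrono>

// Sketch: measure the wall-clock time of one code section in milliseconds.
template <typename Fn>
double TimeSectionMs(Fn&& section) {
  auto t0 = std::chrono::steady_clock::now();
  section();                                   // run the code being measured
  auto t1 = std::chrono::steady_clock::now();
  return std::chrono::duration<double, std::milli>(t1 - t0).count();
}

// Usage (hypothetical section):
//   double ms = TimeSectionMs([&] { ReverseArgsort(scores, order); });
//   printf("ReverseArgsort: %.1f ms\n", ms);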

@xinyu-intel (Contributor)

@wkcn @pengzhao-intel I think there are quite a few units in multi_proposal.cc. We can help find which unit is the bottleneck and then fix it.

@wkcn (Member Author) commented Mar 5, 2018

@pengzhao-intel @xinyu-intel
Thank you! I will give it a try.
Here is the performance breakdown:

name                                                   time (ms)
BBoxTransformInv                                       268
IoUTransformInv                                        not used
FilterBox                                              22
CopyScore                                              18
ReverseArgsort (unstable sort)                         7303
ReorderProposals                                       338
nms (calculate area)                                   286
nms (calculate nms)                                    7547
allocate memory for workspace                          1
copy anchor to workspace_proposal                      0
enumerate all shifted anchors                          9
copy workspace_proposals_base to workspace_proposals   162
assign foreground scores for each anchor               45
prepare output                                         3
Total                                                  16002

Using a stable sort for the anchors (ReverseArgsort) adds about 3000 ms.

Update: 2018-03-09 (626296b)

name                                          time (ms)
BBoxTransformInv                              523
FilterBox                                     35
CopyScore                                     49
ReverseArgsort (unstable sort)                1036
ReorderProposals                              17
nms                                           2573
allocate memory for workspace                 4
GenerateAnchors                               0
copy anchor and score to workspace_proposal   361
enumerate all shifted anchors                 101
prepare output                                0
Total                                         4699

@wkcn (Member Author) commented Mar 5, 2018

For the cpu/gpu consistency test, it is interesting that the number of valid anchors in the CPU implementation and the GPU implementation may differ.
The reason is that the floating-point precision on CPU and GPU is different: the size of some anchors lies right at the minimum-valid-anchor boundary, or their overlap right at the NMS threshold.

The minimum valid anchor size check:
https://github.com/wkcn/incubator-mxnet/blob/add_multi_proposal_cpu_version/src/operator/contrib/multi_proposal.cc#L141
https://github.com/wkcn/incubator-mxnet/blob/add_multi_proposal_cpu_version/src/operator/contrib/multi_proposal.cc#L159

The overlap check:
https://github.com/wkcn/incubator-mxnet/blob/add_multi_proposal_cpu_version/src/operator/contrib/multi_proposal.cc#L273

I want to create test samples that avoid these boundary cases.

@marcoabreu (Contributor) commented Mar 6, 2018 via email

@wkcn (Member Author) commented Mar 6, 2018

@marcoabreu Thank you:) I will commit it soon.

@CodingCat (Contributor)

Hi, the community has voted to associate code changes with JIRA (https://lists.apache.org/thread.html/ab22cf0e35f1bce2c3bf3bec2bc5b85a9583a3fe7fd56ba1bbade55f@%3Cdev.mxnet.apache.org%3E)

We have updated the guidelines for contributors at https://cwiki.apache.org/confluence/display/MXNET/Development+Process. Please ensure that you have created a JIRA at https://issues.apache.org/jira/projects/MXNET/issues/ to describe your work in this pull request, and include the JIRA title in your PR as [MXNET-xxxx] your title, where MXNET-xxxx is the JIRA id.

Thanks!

@pengzhao-intel (Contributor)

@wkcn Thanks for the detailed profiling data.

Regarding sorting, you can use the MKL sort when the library is built with MKL. It will be much faster than a sequential sort.
https://software.intel.com/en-us/mkl-developer-reference-c-lasrt

Regarding the NMS calculation, the current implementation is not ideal performance-wise. @xinyu-intel will run several experiments locally and get back to you later.

@wkcn (Member Author) commented Mar 7, 2018

@pengzhao-intel Thank you!

@wkcn (Member Author) commented Mar 7, 2018

@CodingCat I will read it. Thank you!

@wkcn changed the title from "add multi proposal operator (cpu version) and fix the bug in proposal op (gpu version)" to "[MXNET-40]add multi proposal operator (cpu version) and fix the bug in proposal op (gpu version)" on Mar 7, 2018
@wkcn (Member Author) commented Mar 12, 2018

@TaoLv @ZiyueHuang @precedenceguo @piiswrong @pengzhao-intel @xinyu-intel @marcoabreu @CodingCat

Hello!

Here are the changes in this PR.

  1. I added the CPU implementation of the Multi-Proposal operator
    and used OpenMP to optimize it.
    The performance of the CPU implementation is close to that of the GPU implementation.

    Performance Table

Performance   CPU (no OpenMP)   CPU (OpenMP)   GPU
Time (s)      33.899            5.049          4.435

Testing Code

  2. I wrote the unit test for Multi-Proposal and the CPU/GPU consistency test for Proposal and Multi-Proposal.
  3. I fixed the bugs below:
    - The _nms function in the GPU implementation needs to break the loop when num_to_keep >= rpn_post_nms_top_n.
      https://github.com/apache/incubator-mxnet/pull/9939/files#diff-8ce45deadbf52bb9d13aa3bb9562dddaR367
    - The output of Proposal and Multi-Proposal: the argument of PrepareOutput should be param_.rpn_post_nms_top_n rather than rpn_post_nms_top_n.
      https://github.com/apache/incubator-mxnet/pull/9939/files#diff-8ce45deadbf52bb9d13aa3bb9562dddaR572

Regarding the requested changes, I have removed the USE_STABLE_SORT_FOR_PROPOSAL flag and restored the build settings.

Could you please merge the PR?

Thank you! :-)

@wkcn (Member Author) commented Mar 15, 2018

@marcoabreu @szha @cjolivier01

Hello!

Regarding the requested changes, I have removed the USE_STABLE_SORT_FOR_PROPOSAL flag and restored the build settings.

Could you please review and merge the PR?

Thank you!

@marcoabreu (Contributor) left a comment

LGTM

@wkcn (Member Author) commented Mar 16, 2018

@marcoabreu Thank you! :)

//=====================
// NMS Utils
//=====================
namespace mxnet {
Member

Did you check the sort_op in src/operator/tensor/sort_op.h?

Member Author

@zhreshold Thank you!
It seems that I can use SortByKey from src/operator/tensor/sort_op.h, and may replace the sort in proposal.cc and multi_proposal.cc with it later.
However, SortByKey in src/operator/tensor/sort_op.h uses stable_sort, which is slightly slower than std::sort.

@piiswrong merged commit 6bc2d3f into apache:master Mar 20, 2018
@piiswrong (Contributor)

@wkcn Hi, it looks like the multi proposal operator test fails randomly. Do you have any idea why?
http://jenkins.mxnet-ci.amazon-ml.com/blue/organizations/jenkins/incubator-mxnet/detail/PR-10161/5/pipeline/549#step-919-log-1763

@wkcn (Member Author) commented Mar 22, 2018

@piiswrong
Sorry, I think the issue is that the score of invalid anchors is assigned -1. The CPU implementation uses an unstable sort (std::sort), but the GPU implementation uses a stable sort (thrust::stable_sort_by_key). This leads to inconsistency among the invalid anchors.

CPU/GPU consistency cannot be guaranteed when the output contains invalid anchors whose score is -1.

https://github.com/apache/incubator-mxnet/blob/master/src/operator/contrib/multi_proposal.cc#L89
https://github.com/apache/incubator-mxnet/blob/master/src/operator/contrib/multi_proposal.cc#L141
https://github.com/apache/incubator-mxnet/blob/master/src/operator/contrib/multi_proposal.cc#L159

I followed the Proposal operator and Multi Proposal (GPU) to implement Multi Proposal (CPU).
The output of the Proposal operator may contain invalid anchors, but invalid anchors are not needed. (reference:
https://github.com/apache/incubator-mxnet/blob/master/example/rcnn/rcnn/symbol/proposal.py#L133)

There are three ways to fix the cpu/gpu consistency test for Proposal and MultiProposal:

  1. Ignore invalid anchors whose score is -1 in the comparison.
  2. Use a stable sort in the CPU implementation.
  3. Remove invalid anchors from the output.

I will fix it soon.

@marcoabreu (Contributor)

Hello @wkcn, thanks a lot for the quick response! Do you have an ETA for when this will be resolved? As stated by Eric, we're experiencing a significant number of test failures, which is impacting the overall results of our CI.

@marcoabreu (Contributor)

Feel free to create a PR which disables that test, I'll merge it asap.

@wkcn (Member Author) commented Mar 22, 2018

@marcoabreu
Hello! I have created a new PR to disable the cpu/gpu consistency test temporarily.

I will fix the bug.

@wkcn (Member Author) commented Mar 23, 2018

@piiswrong @marcoabreu
Hello,

I found the reason the CPU/GPU consistency test failed.
It seems that floating-point computation differs between CPU and GPU.

During Non-Maximum Suppression, when the overlap of two anchors is close to the threshold, the overlap may be greater than the threshold in the CPU implementation but less than the threshold in the GPU implementation.

(screenshots: NMS result on CPU, NMS result on GPU)

Here is the output (score) comparison between the Multi Proposal CPU/GPU implementations.

Line    score (CPU)   score (GPU)
10920   0.881283      0.881283
10921   0.881242      0.881242
10922   0.880963      0.880963
10923   0.880916      0.880916
10924   0.880834      0.880834
10925   0.880719      0.880812   <- diverges here
10926   0.880509      0.880719
10927   0.880333      0.880509
10928   0.880228      0.880333
10929   0.879861      0.880228
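
As a toy illustration of this (the numbers below are made up and not from the actual run): when the IoU of a box pair lands essentially on the 0.7 threshold, a last-bit rounding difference between the CPU and GPU code paths can flip the keep/suppress decision for that one box, and from then on every row of the sorted output shifts by one position, exactly like the score comparison above.

#include <cstdio>

int main() {
  const float thresh = 0.7f;
  // Two hypothetical IoU values that differ only by rounding, standing in for
  // the CPU and GPU results of the same overlap computation.
  float iou_cpu = 0.69999999f;   // rounds to the same float as 0.7f -> not greater -> keep
  float iou_gpu = 0.70000005f;   // one ulp above 0.7f -> greater -> suppress
  std::printf("cpu: %s, gpu: %s\n",
              iou_cpu > thresh ? "suppress" : "keep",
              iou_gpu > thresh ? "suppress" : "keep");
  return 0;
}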

ashokei pushed a commit to ashokei/incubator-mxnet that referenced this pull request Mar 27, 2018
…n proposal op (gpu version) (apache#9939)

* add multi proposal operator (cpu version)

* lint fixed

* add the unit-test for MultiProposal Operator

* fix PrepareOutput bug in proposal op(gpu) and multi_proposal op(gpu)

* fix _nms function in the GPU implementation of Proposal and MultiProposal

* use openmp to optimize MultiProposal Operator

* fix MultiProposal Building on Windows

* rm openmp collapse token in MultiProposal.cc

* use signed iterator for MultiProposal

* add cpu/gpu consistency test for Proposal and MultiProposal

* fix test samples in test_multi_proposal_op

* add USE_STABLE_SORT_PROPOSAL=1 for Jenkinsfile

* use stable sort for MultiProposal and change Building Settings back

* skip the cpu/gpu consistency test for Proposal Operator and MultiProposal Operator

* add unittest skip for test_multi_proposal_op

* retrigger test

* optimize multi proposal operator (cpu implementation)

* open the cpu/gpu consistency test for Proposal Op and MultiProposal Op
jinhuang415 pushed a commit to jinhuang415/incubator-mxnet that referenced this pull request Mar 30, 2018
rahul003 pushed a commit to rahul003/mxnet that referenced this pull request Jun 4, 2018
zheng-da pushed a commit to zheng-da/incubator-mxnet that referenced this pull request Jun 28, 2018