
improve stability of test_nms_cuda #2044


Merged: 3 commits merged into pytorch:master on Apr 3, 2020

Conversation

Contributor

@hartb hartb commented Apr 1, 2020

This change addresses two issues:

_create_tensors_with_iou() creates test data for the NMS tests. It
takes care to ensure at least one pair of boxes (1st and last) have
IoU around the threshold for the test. However, the constructed
IoU for that pair is so close to the threshold that rounding
differences (presumably) between CPU and CUDA implementations may
result in one suppressing a box in the pair and the other not.
Adjust the construction to ensure the IoU for the box pair is
near the threshold, but far enough above it that both implementations
should agree.

Where 2 boxes have nearly or exactly the same score, the CPU and
CUDA implementations may order them differently. Adjust
test_nms_cuda() to check only that the non-suppressed box lists
include the same members, without regard for ordering.
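For context, a minimal sketch of the arithmetic behind that construction (the function below is illustrative only, not the test helper): extending the last box's x1 by (x1 - x0) * (1 - t) / t makes the pair's IoU land exactly on the threshold t, so whether the box gets suppressed is decided purely by floating-point rounding; perturbing the denominator by 1e-5 moves the constructed IoU just off that boundary, so both implementations should land on the same side of it.

# Illustrative sketch only (not the test helper): the pair's IoU under the
# original construction (eps=0) and under the perturbed one (eps=1e-5).
def constructed_iou(t, eps=0.0, w=10.0, h=2.0):
    # Box A: width w, height h.  Box B: the same box with x1 extended by
    # w * (1 - t) / (t - eps), so A lies entirely inside B.
    b_width = w + w * (1 - t) / (t - eps)
    inter = w * h
    union = w * h + b_width * h - inter
    return inter / union

print(constructed_iou(0.8))        # ~0.8: mathematically exactly the threshold
print(constructed_iou(0.8, 1e-5))  # ~0.799998: clearly off the boundary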

@hartb
Contributor Author

hartb commented Apr 1, 2020

See: #2035

@peterjc123 peterjc123 requested a review from fmassa April 1, 2020 15:35
Member

@fmassa fmassa left a comment


Thanks for the PR!

I'm unclear about a few things; let me know what you think.

test/test_ops.py Outdated
@@ -399,7 +399,7 @@ def test_nms_cuda(self):
r_cpu = ops.nms(boxes, scores, iou)
r_cuda = ops.nms(boxes.cuda(), scores.cuda(), iou)

- self.assertTrue(torch.allclose(r_cpu, r_cuda.cpu()), err_msg.format(iou))
+ self.assertTrue(set(r_cpu.tolist()) == set(r_cuda.tolist()), err_msg.format(iou))
Member

The returned order of the elements is kind of important -- we ensure that the indices are sorted wrt the scores in decreasing order.

Ideally, we would like to keep this test, but take into account the fact that if two elements have the same score, the order doesn't matter.
One way of doing this is by doing something like what PyTorch does for testing topk, roughly in the lines of:

is_eq = torch.allclose(r_cpu, r_cuda.cpu())
if not is_eq:
    # if the indices are not the same, ensure that it's because the scores
    # are duplicate
    is_eq = torch.allclose(scores[r_cpu], scores[r_cuda.cpu()])
self.assertTrue(is_eq, err_msg.format(iou))

Contributor Author

This problem definitely seems to be due to the lack of stability in torch's tensor.sort(), and it can be recreated with a small snippet:

$ cat sorttest.py 
import torch

for s in range(150):
    torch.manual_seed(s)
    scores = torch.rand(1000)
    sort_indices_cpu = scores.sort(0, descending=True)[1]
    sort_indices_cuda = scores.cuda().sort(0, descending=True)[1]

    if not torch.allclose(sort_indices_cpu, sort_indices_cuda.cpu()):
        print(f'FAIL with seed {s}:')
        print(f'CPU CUDA')
        for a, b in zip(sort_indices_cpu.tolist(), sort_indices_cuda.tolist()):
            if a != b:
                print(f'{a:3} {b:3}')

$ python sorttest.py 
FAIL with seed 64:
CPU CUDA
810 478
478 810
FAIL with seed 122:
CPU CUDA
 58 427
427  58

Your suggested solution is much better; will update the PR.

Contributor Author

If a pair of same-scored boxes (A, B) happened to overlap beyond the IoU threshold and the two implementations sorted them in different orders, then one result would include box A and the other box B.

I guess a deeper solution for this would be to force the use of a stable sort, but it's not clear that torch provides one. The docs for torch.sort don't mention stability either way.
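To make the tie-breaking issue concrete, here is a toy sketch (an editorial illustration, assuming torchvision is installed; not part of this PR): two identical boxes with identical scores overlap completely, so NMS keeps exactly one of them, and which index survives depends entirely on how the backend happened to order the tied scores.

# Toy illustration: identical boxes, identical scores. NMS keeps exactly one;
# which index it keeps is determined by how the (unstable) sort broke the tie.
import torch
from torchvision import ops

boxes = torch.tensor([[0., 0., 10., 10.],
                      [0., 0., 10., 10.]])
scores = torch.tensor([0.5, 0.5])

keep = ops.nms(boxes, scores, iou_threshold=0.5)
print(keep)  # tensor([0]) or tensor([1]); both are equally valid answers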

Member

PyTorch doesn't enforce any guarantees on the order of the sorted elements (especially in CUDA), so I guess we will take the same approach here.

Member

Can you revert this change and keep the other one? I think the tests should pass in that case.

Member

oh, great catch!

I think adding the division in the cuda kernel is fine, can you fix it?

Contributor Author

Yes, I'll do some testing and then adjust the PR.

It may be later in the day today; I have a few other things competing for my attention.

Contributor Author

Hmmm. Made the change to devIoU() and now the example above passes, but about an equal number of tests with different boxes fail. And I can see that there are still slight differences in the IoU computation between the CPU and CUDA routines.

The new results are consistent with the histogram in the other thread here:

  • failures are strongly clustered around the point where the constructed 1st/last-box IoU almost exactly matches the target threshold IoU
  • where the constructed IoU is slightly over-target, the CPU code will sometimes suppress the box when CUDA code doesn't
  • and vice versa when the constructed IoU is slightly under-target

I'll keep looking but so far don't see other low-hanging fruit.

Contributor Author

Ha! If I add printfs of partial results in those routines, the mismatches go away. Just building without optimizations doesn't seem to prevent the mismatches, though.

Member

Thanks a lot for the investigation!

In the interest of getting CI green, I think we should move forward with the mitigation approach you proposed (don't use the exact IoU), but it would be great to get to the bottom of this in the future.

test/test_ops.py Outdated
@@ -378,7 +378,7 @@ def _create_tensors_with_iou(self, N, iou_thresh):
boxes[:, 2:] += boxes[:, :2]
boxes[-1, :] = boxes[0, :]
x0, y0, x1, y1 = boxes[-1].tolist()
- boxes[-1, 2] += (x1 - x0) * (1 - iou_thresh) / iou_thresh
+ boxes[-1, 2] += (x1 - x0) * (1 - iou_thresh) / (iou_thresh - 1e-5)
Member

Although this might fix the tests, I'm not clear on where in the current implementation the rounding makes things behave differently, so the "bug" still persists but is no longer tested.

Do you have any insights on where the discrepancy might be?

Contributor Author

Not precisely, but with this code driving the existing test functions:

import torch
from torchvision import ops

# _create_tensors_with_iou() is the helper defined in test/test_ops.py
N = 30

for s in range(50):
    for iou in [0.2, 0.5, 0.8]:
        torch.manual_seed(s)
        boxes, scores = _create_tensors_with_iou(N, iou)
        r_cpu = ops.nms(boxes, scores, iou).tolist()
        r_cuda = ops.nms(boxes.cuda(), scores.cuda(), iou).tolist()
        if r_cpu != r_cuda:
            excess_boxes = "CPU" if set(r_cpu).difference(set(r_cuda)) else "CUDA"
            endbox_iou = ops.box_iou(boxes[0].unsqueeze(0), boxes[N-1].unsqueeze(0))
            print(f'Seed: {s:3}, Excess: {excess_boxes:4}, iou diff: {(iou - endbox_iou[0][0]):+.12f}')

I can see that the excess box in each failing test (i.e., the box that the CPU or CUDA implementation failed to remove) appears to depend on whether the difference between the target/threshold IoU value and the constructed 1st/last-box IoU is positive or negative:

$ python test_foo.py
Seed:   3, Excess: CUDA, iou diff: -0.000000059605
Seed:   7, Excess: CPU , iou diff: +0.000000000000
Seed:  12, Excess: CUDA, iou diff: -0.000000119209
Seed:  14, Excess: CPU , iou diff: +0.000000000000
Seed:  25, Excess: CUDA, iou diff: -0.000000059605
Seed:  31, Excess: CPU , iou diff: +0.000000000000
Seed:  32, Excess: CPU , iou diff: +0.000000059605
Seed:  35, Excess: CUDA, iou diff: -0.000000029802
Seed:  41, Excess: CPU , iou diff: +0.000000000000

The differences between target IoU and constructed IoU are pretty narrow in the failing cases. Searching over 5000 random seeds, I see 1421 failures, with the target/constructed differences ranging from -1.788139343262e-07 to 1.192092895508e-07.

In the passing cases, the target/constructed IoU differences are generally larger. I see differences ranging from -8.642673492432e-06 through 1.275539398193e-05 (with one outlier at 0.002531647682!).

Contributor Author

It's easier to see what I'm describing as a histogram of the difference between target and constructed IoU versus whether the test passes or shows extra boxes in the CPU or CUDA result:

Threshold - constructed IoU Pass  CPU CUDA
--------------------------- ---- ---- ----
-0.000008642673492431640625    1    0    0
-0.000005960464477539062500    1    0    0
-0.000005662441253662109375    1    0    0
-0.000004649162292480468750    1    0    0
-0.000003635883331298828125    1    0    0
-0.000002622604370117187500    1    0    0
-0.000002503395080566406250    1    0    0
-0.000002205371856689453125    1    0    0
-0.000001966953277587890625    1    0    0
-0.000001788139343261718750    1    0    0
-0.000001311302185058593750    1    0    0
-0.000001251697540283203125    2    0    0
-0.000001192092895507812500    2    0    0
-0.000001013278961181640625    1    0    0
-0.000000894069671630859375    3    0    0
-0.000000834465026855468750    1    0    0
-0.000000715255737304687500    4    0    0
-0.000000655651092529296875    1    0    0
-0.000000596046447753906250    5    0    0
-0.000000536441802978515625    3    0    0
-0.000000476837158203125000    7    0    0
-0.000000417232513427734375    5    0    0
-0.000000357627868652343750    9    0    0
-0.000000298023223876953125   14    0    0
-0.000000238418579101562500   26    0    0
-0.000000178813934326171875   69    0   20
-0.000000119209289550781250  254    0   92
-0.000000059604644775390625  684    0  298
 0.000000000000000000000000 1423   98    0
 0.000000059604644775390625 1035  159    0
 0.000000119209289550781250  484   40    0
 0.000000178813934326171875  140    0    0
 0.000000238418579101562500   42    0    0
 0.000000298023223876953125   13    0    0
 0.000000357627868652343750    7    0    0
 0.000000417232513427734375    8    0    0
 0.000000476837158203125000    7    0    0
 0.000000536441802978515625    7    0    0
 0.000000596046447753906250    3    0    0
 0.000000655651092529296875    6    0    0
 0.000000715255737304687500    2    0    0
 0.000001013278961181640625    1    0    0
 0.000001072883605957031250    1    0    0
 0.000001132488250732421875    1    0    0
 0.000001192092895507812500    1    0    0
 0.000001370906829833984375    1    0    0
 0.000001430511474609375000    1    0    0
 0.000001490116119384765625    1    0    0
 0.000001549720764160156250    1    0    0
 0.000001728534698486328125    2    0    0
 0.000001966953277587890625    1    0    0
 0.000012755393981933593750    1    0    0
 0.002531647682189941406250    1    0    0

Member

I would be ok with merging this change (and thus not really exercising the test case I wanted with the original PR introducing this change), but I think it's more trouble than it's worth.

Contributor Author

And a concrete failing snippet:

$ cat nms.py 
import torch
from torchvision import ops

torch.set_printoptions(precision=12)

iou = 0.8

scores = torch.tensor([0.75, 0.25])
boxes = torch.tensor([[ 0.426352024078369140625000, 10.556936264038085937500000,
                       29.010599136352539062500000, 13.252419471740722656250000],
                      [0, 0, 0, 0]])

boxes[-1, :] = boxes[0, :]
x0, y0, x1, y1 = boxes[-1].tolist()
boxes[-1, 2] += (x1 - x0) * (1 - iou) / iou

print("Box[0]: ", boxes[0])
print("Box[1]: ", boxes[1])
print("IoU target: ", iou, ", IoU actual: ", ops.box_iou(boxes[0].unsqueeze(0), boxes[1].unsqueeze(0)).squeeze(0).item())
print("CPU:  ", ops.nms(boxes, scores, iou))
print("CUDA: ", ops.nms(boxes.cuda(), scores.cuda(), iou))

Which gives:

$ python nms.py 
Box[0]:  tensor([ 0.426352024078, 10.556936264038, 29.010599136353, 13.252419471741])
Box[1]:  tensor([ 0.426352024078, 10.556936264038, 36.156661987305, 13.252419471741])
IoU target:  0.8 , IoU actual:  0.8000000715255737
CPU:   tensor([0])
CUDA:  tensor([0, 1], device='cuda:0')

hartb added 2 commits April 1, 2020 19:51
The CPU and CUDA nms implementations each sort the box scores
as part of their work, but the sorts they use are not stable. So
boxes with the same score may be processed in the opposite order
by the two implementations.

Relax the assertion in test_nms_cuda (following the model in
pytorch's test_topk()) to allow the test to pass if the output
differences are caused by similarly-scored boxes.
Adjust _create_tensors_with_iou() to ensure we create at least
one box just over the threshold that should be suppressed.
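For reference, the relaxed comparison described in the first commit would sit in test_nms_cuda() roughly as follows (a sketch based on the reviewer's suggestion and the diff context above, not the verbatim merged code):

# Approximate shape of the relaxed check in test_nms_cuda(); the names
# (boxes, scores, iou, err_msg) are those used in the diffs and snippets above.
r_cpu = ops.nms(boxes, scores, iou)
r_cuda = ops.nms(boxes.cuda(), scores.cuda(), iou)

is_eq = torch.allclose(r_cpu, r_cuda.cpu())
if not is_eq:
    # The kept-index lists may differ only because tied scores were ordered
    # differently by the two backends; in that case the gathered scores match.
    is_eq = torch.allclose(scores[r_cpu], scores[r_cuda.cpu()])
self.assertTrue(is_eq, err_msg.format(iou))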
@codecov-io

Codecov Report

Merging #2044 into master will not change coverage.
The diff coverage is n/a.


@@          Coverage Diff           @@
##           master   #2044   +/-   ##
======================================
  Coverage    0.48%   0.48%           
======================================
  Files          92      92           
  Lines        7411    7411           
  Branches     1128    1128           
======================================
  Hits           36      36           
  Misses       7362    7362           
  Partials       13      13


Member

@fmassa fmassa left a comment


Thanks a lot for the work and investigation!

Let's merge this as is to unblock CI, and let's keep investigating in the future (to hopefully get to the bottom of it).

@fmassa fmassa merged commit e61538c into pytorch:master Apr 3, 2020
@fmassa fmassa mentioned this pull request Apr 3, 2020