test_nms_cuda is flaky #2035

Closed
peterjc123 opened this issue Mar 31, 2020 · 13 comments

🐛 Bug

https://app.circleci.com/pipelines/github/pytorch/vision/2097/workflows/661fd235-202a-4c88-be4d-f8af378c195f/jobs/110511

================================== FAILURES ===================================
___________________________ NMSTester.test_nms_cuda ___________________________

self = <test_ops.NMSTester testMethod=test_nms_cuda>

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
    def test_nms_cuda(self):
        err_msg = 'NMS incompatible between CPU and CUDA for IoU={}'
    
        for iou in [0.2, 0.5, 0.8]:
            boxes, scores = self._create_tensors_with_iou(1000, iou)
            r_cpu = ops.nms(boxes, scores, iou)
            r_cuda = ops.nms(boxes.cuda(), scores.cuda(), iou)
    
>           self.assertTrue(torch.allclose(r_cpu, r_cuda.cpu()), err_msg.format(iou))
E           RuntimeError: The size of tensor a (461) must match the size of tensor b (460) at non-singleton dimension 0

test\test_ops.py:403: RuntimeError

peterjc123 (Contributor, Author)

@fmassa Do you know what happened there?

fmassa (Member) commented Mar 31, 2020

I have seen this error before, and I thought I had fixed it in #1556, but maybe there is still a corner case that gets missed sometimes...

hartb (Contributor) commented Mar 31, 2020

Happened to be looking at this today. The mismatch is because sometimes either the CPU or the CUDA result includes either box 0 or box 999 (depending on their relative scores) while the other does not.

I see the problem on both x86 and Power, and can force it consistently across them by seeding the torch RNG with specific values. For example, I'll see a failure with:

$ diff -u test_ops.py.orig test_ops.py
--- test_ops.py.orig    2020-03-31 17:23:09.453600964 -0400
+++ test_ops.py 2020-03-31 17:23:36.909953703 -0400
@@ -365,6 +365,7 @@
     def test_nms_cuda(self):
         err_msg = 'NMS incompatible between CPU and CUDA for IoU={}'
 
+        torch.manual_seed(12)
         for iou in [0.2, 0.5, 0.8]:
             boxes, scores = self._create_tensors_with_iou(1000, iou)
             r_cpu = ops.nms(boxes, scores, iou)

The problem seems to be the new code in _create_tensors_with_iou(), which tries to ensure the first and last boxes have an IoU matching the test threshold. The IoU that results from that code is so close to the threshold that whether the box gets excluded can go either way, on CPU or CUDA, depending on rounding.

I'm able to avoid the problem by introducing an epsilon value so that the 1st/last box IoU always ends up a bit above the IoU test threshold:

--- test_ops.py.orig    2020-03-31 21:34:14.662355038 +0000
+++ test_ops.py 2020-03-31 21:34:42.911263943 +0000
@@ -379,7 +379,7 @@
         boxes[:, 2:] += boxes[:, :2]
         boxes[-1, :] = boxes[0, :]
         x0, y0, x1, y1 = boxes[-1].tolist()
-        boxes[-1, 2] += (x1 - x0) * (1 - iou_thresh) / iou_thresh
+        boxes[-1, 2] += (x1 - x0) * (1 - iou_thresh) / (iou_thresh - 1e-5)
         scores = torch.rand(N)
         return boxes, scores
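
For intuition, here is why the original construction puts the end-box IoU exactly on the threshold, and why perturbing the divisor moves the pair off that decision boundary. A quick sketch with made-up coordinates (the box layout follows the diff above):

w, h = 20.0, 20.0                        # made-up width/height of box 0
for t in (0.2, 0.5, 0.8):
    # The last box copies box 0 and only widens xmax by d, so in exact
    # arithmetic: intersection = w * h, union = (w + d) * h, and
    # IoU = w / (w + d) = 1 / (1 + (1 - t) / t) = t -- right on the threshold,
    # where CPU/CUDA rounding differences decide whether the box is suppressed.
    d_exact = w * (1 - t) / t            # original construction
    d_eps = w * (1 - t) / (t - 1e-5)     # perturbed divisor from the diff
    print(t, w / (w + d_exact), w / (w + d_eps))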

That reliably gets rid of the tensor length difference complaints, but the test still sometimes fails the assertion. Looking at that now.

You can get a better look by driving the test routines with something like:

import torch
from torchvision import ops

# _create_tensors_with_iou is the test helper from test/test_ops.py
# (the method patched in the diffs above), used here as a free function.

N = 30

for s in range(100):
    for iou in [0.2, 0.5, 0.8]:
        torch.manual_seed(s)
        boxes, scores = _create_tensors_with_iou(N, iou)
        endbox_iou = ops.box_iou(boxes[0].unsqueeze(0), boxes[N - 1].unsqueeze(0)).tolist()[0][0]
        r_cpu = set(ops.nms(boxes, scores, iou).tolist())
        r_cuda = set(ops.nms(boxes.cuda(), scores.cuda(), iou).tolist())
        if r_cpu != r_cuda:
            print(f'Seed: {s:3}, IoU: {iou}, Endbox IoU: {endbox_iou:.12f}, Mismatch: {list(r_cpu.symmetric_difference(r_cuda))}')

hartb (Contributor) commented Mar 31, 2020

And it looks like the remaining assertion failure I'm seeing is an ordering problem, arising when two of the boxes that make the cut have very similar scores. So:

  • r_cpu includes [...., 554, 788, ...]
  • r_cuda includes [..., 788, 554, ...]

Where the scores of the two boxes appear to be identical:

  • 554 -> 0.60010164976119995117
  • 788 -> 0.60010164976119995117

Not sure of the best way to avoid that. We could compare the results without respect to the original ordering (e.g. by sorting, or by converting to sets), but maybe keeping the ordering check is nice. I'm also not sure of a sensible way to ensure all the boxes have distinct scores, short of assigning them incrementally, and that doesn't seem ideal.
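
For illustration, a minimal sketch of the set-based comparison, written as it might look inside the test_nms_cuda loop shown at the top of this issue (just one way to express the idea, not necessarily what a PR would look like):

for iou in [0.2, 0.5, 0.8]:
    boxes, scores = self._create_tensors_with_iou(1000, iou)
    r_cpu = ops.nms(boxes, scores, iou)
    r_cuda = ops.nms(boxes.cuda(), scores.cuda(), iou)

    # Compare the surviving indices as sets, so tied scores that come
    # back in a different order don't fail the test.
    self.assertEqual(set(r_cpu.tolist()), set(r_cuda.cpu().tolist()),
                     err_msg.format(iou))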

hartb (Contributor) commented Mar 31, 2020

Oh, and I'm testing on Linux on both IBM Power and x86, so this isn't a Windows-specific problem, despite where the CI failure arose.

Testing with the v0.5.1 branch against PyTorch 1.4.0 and 1.5.0-rc.

peterjc123 changed the title from "Windows CI failed" to "test_nms_cuda is flaky" on Apr 1, 2020
hartb (Contributor) commented Apr 1, 2020

Put up a PR that adjusts the 1st/last box IoU construction (to ensure it's slightly over the threshold, so the box should be consistently suppressed by both the CPU and CUDA implementations) and compares the surviving box lists as sets, without regard for ordering.

Happy to rework those if other solutions are preferred.

hartb (Contributor) commented Apr 1, 2020

Hmmm. It occurs to me now that the 2nd issue (ordering of boxes with very similar scores) might be caused by the CUDA implementation not using a stable sort.

fmassa (Member) commented Apr 1, 2020

@hartb thanks for the PR!

I think we should not try to guarantee that the sorting returns the same results between CPU and CUDA (PyTorch doesn't enforce this), but we should enforce that the number of returned elements is the same.

I commented on the PR; let me know what you think.

fmassa (Member) commented Apr 3, 2020

This has been mitigated in #2044, but it would be great if we could understand more deeply why those discrepancies occur (and fix them if possible).

hartb (Contributor) commented Apr 6, 2020

I looked at this some more...

If devIoU() is changed as discussed in #2044 (using a division while calculating IoU, similar to the CPU calculation), then it's possible to compare the overlap values calculated by CPU and CUDA directly.
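
For reference, the division-based overlap being compared looks roughly like this (a Python sketch of the per-pair IoU in the (x1, y1, x2, y2) box format, not the actual devIoU() CUDA code):

def pair_iou(a, b):
    # a, b: boxes as (x1, y1, x2, y2)
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(ix2 - ix1, 0.0) * max(iy2 - iy1, 0.0)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    # The point of the change discussed above: compute the overlap with an
    # explicit division, mirroring the CPU calculation.
    return inter / (area_a + area_b - inter)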

With that change, overlap values calculated by CPU vs CUDA usually agree exactly, but they can differ from one another by up to 4 ULP (very rarely). Across 1000 seeds:

ULP diff   # instances
0          828
1           16
2          115
3           40
4            1
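
One way to measure that kind of ULP distance between two float32 overlap values (a sketch, not the script that produced the table above):

import struct

def ulp_diff_f32(a, b):
    # Reinterpret the float32 bit patterns as integers; for finite values of
    # the same sign (IoU lies in [0, 1]) the difference counts how many
    # representable float32 values separate a and b.
    ia = struct.unpack('<i', struct.pack('<f', a))[0]
    ib = struct.unpack('<i', struct.pack('<f', b))[0]
    return abs(ia - ib)

print(ulp_diff_f32(0.5, 0.50000006))   # two overlaps one representable step apart -> 1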

Even when the calculated overlaps differ, the testcase will still pass unless the overlap values straddle the threshold (i.e. one overlap greater than the threshold, the other not).

If NVCC's fused multiply-add (fmad) optimization is disabled (e.g. by adding --fmad=false to NVCC_FLAGS), then the overlaps calculated by CPU and CUDA agree exactly, and the test case does not fail (even if PR 2044's change to _create_tensors_with_iou() is reverted).

Disabling fmad without changing devIoU() improves things (reduces failure rate to about 25%), but does not prevent the problem entirely.

NVCC enables fmad by default. It presumably benefits performance.
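
For a feel of why fused multiply-add changes the low bits, here is a tiny, contrived illustration (requires Python 3.13+ for math.fma; the values have nothing to do with the NMS kernels):

import math

a, b, c = 1.0 + 2.0**-30, 1.0 - 2.0**-30, -1.0

# a * b is exactly 1 - 2**-60, which rounds to 1.0 in double precision, so the
# two-step version loses the tiny term entirely...
unfused = a * b + c           # -> 0.0
# ...while a fused multiply-add rounds only once and keeps it.
fused = math.fma(a, b, c)     # -> -2**-60

print(unfused, fused)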

Disabling NVCC's precise division option (--prec-div) did not affect the results at all.

To summarize:

  • small benefit (~12%) just from updating devIoU() to use a similar division
  • large benefit (~75%) just from disabling NVCC's fmad optimization
  • a complete fix from doing both
  • disabling fmad may affect performance

So maybe...

  • update devIoU() to use a division similar to the CPU implementation
  • leave fmad at its default (enabled)
  • keep PR 2044's changes to _create_tensors_with_iou(), which push the generated data a bit away from the threshold

If that sounds OK, I can put up a PR for the devIoU() change.

fmassa (Member) commented Apr 7, 2020

@hartb thanks a lot for getting to the bottom of this!

Your proposal sounds good to me; I agree that disabling fmad would not be ideal (we could have unexpected performance hits).
It would be great if you could send the PR to change devIoU.

hartb (Contributor) commented Apr 7, 2020

Submitted: #2072

fmassa (Member) commented Apr 7, 2020

Thanks @hartb! This has been fixed now.

fmassa closed this as completed on Apr 7, 2020