test_sum can fail with unstable summation #168
IIRC, PyTorch does implement a stable algorithm as well, but I don't know how fine this algorithm needs to be. cc @peterbell10 who I think implemented this algorithm. For how to solve the PyTorch problem, two approaches come to mind. The first one is to implement a fancier algorithm; I discussed the one proposed in this recent paper https://www.youtube.com/watch?v=B1eFGn5nN84 with Peter last year. The simpler solution would be to do the accumulation in int64_t and just cast down to float32_t at the end of the full computation. I reckon the perf hit of this last approach would be minimal, and we would get all of the gains in most cases.
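As a minimal sketch of the "accumulate in a wider type and cast down at the end" idea (using a float64 accumulator for float32 input here; this is not PyTorch's actual implementation, and the helper name is made up):

```python
import numpy as np

def sum_f32_with_wide_accumulator(x: np.ndarray) -> np.float32:
    # Hypothetical helper: accumulate the float32 input in float64 and only
    # round back down to float32 once, at the very end.
    acc = np.float64(0.0)
    for v in x:
        acc = acc + np.float64(v)
    return np.float32(acc)
```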
As for the particular case... I don't think this should be tested. If you are making a case for testing this, you'll be able to make similar cases for all kinds of weird inputs.
In PyTorch we use a cascading summation which is much more stable than naive summation. IIRC NumPy uses the exact same algorithm, in fact. The only difference is that we don't promote float to double before doing the summation. IMO the problem here is the error metric. You just can't expect a floating point summation to survive large cancellations orders of magnitude larger than the result. Promoting to double or using Kahan summation raises the acceptable magnitude difference, but can never eliminate the problem entirely. Instead I would suggest an error metric that captures the error relative to the order of magnitude of the summands, in which case NumPy's error is 0.0 and PyTorch's error is 3e-8, which is around the float epsilon.
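A sketch of a metric of that general shape (the choice of scale used here, the sum of |x|, is an assumption, and the names are illustrative):

```python
import numpy as np

def summand_relative_error(result, reference, x):
    # Scale the absolute error by the overall magnitude of the summands
    # rather than by the (possibly tiny, heavily cancelled) exact sum.
    scale = max(float(np.abs(np.asarray(x, dtype=np.float64)).sum()), 1.0)
    return abs(float(result) - float(reference)) / scale
```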
I'm sorry if it wasn't clear, but I'm making a case that this shouldn't be tested, because the spec doesn't require any such thing. Using a coarser error metric is actually a better idea than trying to limit the input data. That way it will work for any input. We need to support float64 as well, but presumably we could replace a cast to float64 with an exact calculation with Fraction.
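A minimal sketch of the Fraction idea (the helper name is made up for illustration):

```python
from fractions import Fraction

def exact_sum(xs):
    # Every finite float converts to a Fraction exactly, so this reference
    # sum has no rounding error regardless of the input dtype.
    total = sum(Fraction(float(v)) for v in xs)
    return float(total)  # round once, at the end, for comparison
```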
This isn't specific to pytorch. The same problem can happen with NumPy if you run enough examples.
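For instance (an illustrative case rather than one of the actual hypothesis failures), NumPy's float32 sum loses a small summand to cancellation just as easily when no upcast is requested:

```python
import numpy as np

x = np.asarray([1e8, 1.0, -1e8], dtype=np.float32)
# The exact sum is 1.0, but 1e8 + 1.0 rounds back to 1e8 in float32,
# so the small term vanishes.
print(np.sum(x, dtype=np.float32))  # 0.0
print(np.sum(x, dtype=np.float64))  # 1.0
```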
We really need to just make this test less susceptible to loss of significance issues. |
I've also been seeing errors with test_prod.
I've seen test_prod fail for torch in other runs as well. I believe the underlying issue is the same. The test is assuming that underflow doesn't occur, but that's not a reasonable assumption for real-world implementations, especially in lower precision.
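As an illustration of the underflow side (again, not one of the actual failing cases):

```python
import numpy as np

x = np.full(10, 1e-5, dtype=np.float32)
# The exact product is about 1e-50, which is below the smallest positive
# float32 subnormal (about 1.4e-45), so a float32 implementation underflows
# to 0.0, while a float64 reference computation does not.
print(np.prod(x, dtype=np.float32))  # 0.0
print(np.prod(x, dtype=np.float64))  # ~1e-50
```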
Looks like the new cumulative_sum test has the same issue.
I'd rather not skip this test, as it's for a new function, so we will likely need to figure out this issue.
… sum/cumulative_sum/prod
This isn't completely rigorous (I haven't tweaked the tolerances used in isclose from the generous ones we were using before), but I haven't gotten hypothesis to find any bad corner cases for this yet. If any crop up, we can easily tweak the values. Fixes data-apis#168
This test fails with pytorch's sum(). Note that you need to use data-apis/array-api-compat#14 because torch.sum has other issues with its signature and type promotion that will fail the test earlier.
torch truncates the float values when printing, but the full tensor is float32, the torch default dtype. The problem is that if you add the elements in the wrong order, you get the wrong sum (5 instead of 7).
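Here is an illustrative float32 reconstruction of the effect (not the original failing tensor): the same elements summed in two orders give different results.

```python
import numpy as np

# Illustrative values only. In float32, 1.0 + 2**25 rounds back to 2**25,
# so whether the small terms survive depends on the order in which the
# elements are added.
x = np.asarray([1.0, 2.0**25, 4.0, -(2.0**25)], dtype=np.float32)

forward = np.float32(0.0)
for v in x:                 # left-to-right: the 1.0 is lost
    forward = forward + v

backward = np.float32(0.0)
for v in x[::-1]:           # right-to-left: every addition is exact
    backward = backward + v

print(forward, backward)    # 4.0 5.0 -- the exact sum is 5.0
```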
NumPy's sum does this correctly. I don't know whether, if we ran the tests on NumPy for long enough, a similar situation would also come up. It's really hard to do summation stably. My guess is that the only reason NumPy passes is because NumPy happens to do exactly what we are doing in the test, i.e., upcast to float64 (i.e., Python float) and sum in C-order. But PyTorch's default dtype is float32, so it isn't required to return a float64 from sum(dtype=None) and therefore doesn't upcast internally.

The spec doesn't require sum to use a stable algorithm (like Kahan) vs. naive term-wise summation, and doesn't require the terms to be summed in any particular order. So we should avoid generating examples that can lead to loss of significance. This might require some research, but I think it should be sufficient to avoid generating examples that are too far from each other in magnitude; a sketch of that idea follows below. Another idea would be to compute the exact sum (e.g., using Fraction) and compare it to a naive summation. One problem is that the amount of loss of significance depends on the order of summation, and there's no guarantee in which order the library will sum the array.

Also, even if we do this, it's problematic to check the summation by first upcasting to Python float when the input is float32, because that will inherently produce a more accurate result than if the algorithm worked directly on the float32.
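A rough sketch of the "keep the magnitudes close together" idea for the input strategy (the bounds, names, and strategy shape here are illustrative placeholders, not the actual array-api-tests strategies; keeping everything positive also sidesteps cancellation entirely):

```python
from hypothesis import given, strategies as st

# Placeholder strategy: all elements are positive and within a few orders of
# magnitude of each other, so the exact sum can never be much smaller than
# the largest summand and naive float32 summation stays close to a reference.
elements = st.floats(
    min_value=1.0,
    max_value=1e3,
    width=32,
    allow_nan=False,
    allow_infinity=False,
)

@given(st.lists(elements, min_size=1, max_size=100))
def test_sum_without_cancellation(values):
    ...  # call the library's sum() and compare against an exact reference
```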