Increase performance for conversions including axis angles #1948

Conversation
@bottler let's continue the conversation here.
Yes, let's add the "fast" to both cases. And I don't mind the eps being inconsistent: actually, feel free to add it everywhere, in these two places, or nowhere.
Okay, also, one final question @bottler. The implementation of `axis_angle_to_quaternion`:

```python
def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
    angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
    half_angles = angles * 0.5
    eps = 1e-6
    small_angles = angles.abs() < eps
    sin_half_angles_over_angles = torch.empty_like(angles)
    sin_half_angles_over_angles[~small_angles] = (
        torch.sin(half_angles[~small_angles]) / angles[~small_angles]
    )
    # for x small, sin(x/2) is about x/2 - (x/2)^3/6
    # so sin(x/2)/x is about 1/2 - (x*x)/48
    sin_half_angles_over_angles[small_angles] = (
        0.5 - (angles[small_angles] * angles[small_angles]) / 48
    )
    quaternions = torch.cat(
        [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
    )
    return quaternions
```

could become:

```python
def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
    angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
    sin_half_angles_over_angles = 0.5 * torch.sinc(0.5 * angles / torch.pi)
    return torch.cat(
        [torch.cos(0.5 * angles), axis_angle * sin_half_angles_over_angles], dim=-1
    )
```

and this also avoids setting the `eps`. From my testing, there does not seem to be a downside to using `torch.sinc`:

```python
import math

import numpy as np
import torch

for aprx in range(5):
    sinx_div_x_taylor = lambda x: sum(
        ((-1 if n % 2 == 1 else 1) / math.factorial(n)) * (x ** (2 * n))
        for n in range(aprx)
    )
    res = torch.tensor([
        sinx_div_x_taylor(i) - torch.sinc(torch.tensor(i) / torch.pi).item()
        for i in np.linspace(0, 1e-6, 1000)
    ])
    print(f"taylor approximation: {aprx} - max error: {res.abs().max()} - mean error: {res.abs().mean()}")
```

which outputs:

```
taylor approximation: 0 - max error: 1.0 - mean error: 1.0
taylor approximation: 1 - max error: 1.66644475996236e-13 - mean error: 5.558120630411167e-14
taylor approximation: 2 - max error: 8.333334022836425e-13 - mean error: 2.779195762414588e-13
taylor approximation: 3 - max error: 8.333334022836425e-13 - mean error: 2.779195762414588e-13
taylor approximation: 4 - max error: 8.333334022836425e-13 - mean error: 2.779195762414588e-13
```
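As a sanity check on the refactor discussed above, here is a small self-contained comparison of the two variants. Both are re-typed here for illustration under my own names (`_eps` and `_sinc` suffixes are mine, not PyTorch3D's):

```python
import torch


def axis_angle_to_quaternion_eps(axis_angle: torch.Tensor) -> torch.Tensor:
    # Existing variant: explicit small-angle Taylor branch guarded by eps.
    angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
    half_angles = angles * 0.5
    small_angles = angles.abs() < 1e-6
    s = torch.empty_like(angles)
    s[~small_angles] = torch.sin(half_angles[~small_angles]) / angles[~small_angles]
    s[small_angles] = 0.5 - (angles[small_angles] * angles[small_angles]) / 48
    return torch.cat([torch.cos(half_angles), axis_angle * s], dim=-1)


def axis_angle_to_quaternion_sinc(axis_angle: torch.Tensor) -> torch.Tensor:
    # Proposed variant: branch-free, using sinc(x) = sin(pi * x) / (pi * x),
    # so 0.5 * sinc(t / (2 * pi)) == sin(t / 2) / t with sinc handling t == 0.
    angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
    s = 0.5 * torch.sinc(0.5 * angles / torch.pi)
    return torch.cat([torch.cos(0.5 * angles), axis_angle * s], dim=-1)


aa = torch.randn(1000, 3)
aa = torch.cat([aa, aa * 1e-8, torch.zeros(1, 3)])  # include tiny and zero angles
diff = (axis_angle_to_quaternion_eps(aa) - axis_angle_to_quaternion_sinc(aa)).abs().max()
print(diff.item())  # the two variants agree to within float precision
```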
Yes please! I think the change to `torch.sinc` would be a good idea. It wasn't tried before. Thanks!
Prefer `torch.special.sinc` to `torch.sinc`.
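For context: in current PyTorch, `torch.sinc` is documented as an alias of `torch.special.sinc`, so this is a namespace-style preference rather than a behavioral difference. A quick check (assuming a reasonably recent PyTorch):

```python
import torch

x = torch.linspace(-3.0, 3.0, steps=7)
# torch.sinc is documented as an alias of torch.special.sinc,
# so both calls produce identical results.
assert torch.equal(torch.sinc(x), torch.special.sinc(x))
print(torch.special.sinc(torch.tensor([0.0])))  # sinc(0) == 1
```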
```python
import torch

naive = lambda x: torch.sin(x) / x if x != 0 else 1
res = torch.tensor([
    naive(i) - torch.sinc(i / torch.pi).item()
    for i in torch.linspace(0, 1e-6, 100000)
])
assert res.abs().max() == 0.0
assert res.abs().mean() == 0.0
```
So the whole Taylor series thing probably wasn't useful at all!
…and handle edge case in axis_angle <-> rot matrix conv
It seems so! So, I spent quite a lot of time figuring out what to do in the conversion. In total, I feel quite confident to set
In which case, is this change to that function actually a speed improvement? I suggest either not bothering or, I guess, doing something less precise and acknowledging that in the docstring. I think what this PR is really trying to do is speed up cases where you have a simple speedup with an approach already thought to be accurate enough for your application. That might only be true in one direction.
Well, considering a uniform distribution of angles, this should still be a speed improvement (unless all angles are near pi, for example). However, calling the functions for the intermediate conversion even with empty tensors led to very slow performance (I'm not sure why). So, I just committed a change to avoid this. Now there is a total improvement in speed across all cases, plus the function is stable in all edge cases.
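As a side note on the pi edge case: the instability is intrinsic to the problem, not an implementation detail. A minimal sketch (using `torch.matrix_exp` as an independent reference, not PyTorch3D's converters) shows that rotating by pi about an axis and about its negation produce the same matrix, so the sign of the axis is unrecoverable there:

```python
import torch


def skew(v: torch.Tensor) -> torch.Tensor:
    # Skew-symmetric matrix K such that K @ u == torch.cross(v, u).
    x, y, z = v.tolist()
    return torch.tensor([[0.0, -z, y], [z, 0.0, -x], [-y, x, 0.0]])


axis = torch.tensor([0.0, 0.0, 1.0])              # unit rotation axis
R_pos = torch.matrix_exp(skew(axis * torch.pi))   # rotate by +pi
R_neg = torch.matrix_exp(skew(-axis * torch.pi))  # rotate by -pi
print((R_pos - R_neg).abs().max().item())  # ~0: both give the same matrix
```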
@bottler has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Hey @bottler, I can see that some tests are failing; however, I do not have access to see exactly what fails. Is it possible to share the error logs so I can fix this?
Thanks for this work! Tests are fine. You probably saw some reflection of longstanding failures.
Okay, I see. Glad I could help. 🙌 |
Summary: A continuation of #1948 -- this commit fixes a small numerical issue with `matrix_to_axis_angle(..., fast=True)` near `pi`. @bottler feel free to check this out; it's a single-line change.

Pull Request resolved: #1953
Reviewed By: MichaelRamamonjisoa
Differential Revision: D70088251
Pulled By: bottler
fbshipit-source-id: 54cc7f946283db700cec2cd5575cf918456b7f32
This is an extension of #1544 with various speed, stability, and readability improvements. (I could not find a way to make a commit to the existing PR.) This PR is still based on Rodrigues' rotation formula.
The motivation is the same: this change speeds up the conversions by up to 10x, depending on the device, batch size, etc.
Notes

- For angles near π, the existing implementation and the proposed one start to differ. However, (my understanding is that) this is not a problem, as in this case the axis cannot, in general, be stably inferred from the rotation matrix.
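Since the PR is based on Rodrigues' rotation formula, here is a minimal sketch of what a branch-free, sinc-based `axis_angle` → matrix conversion along these lines can look like. This is illustrative only (the function name and exact formulation are mine; PyTorch3D's actual implementation may differ):

```python
import torch


def axis_angle_to_matrix_rodrigues(axis_angle: torch.Tensor) -> torch.Tensor:
    """Rodrigues' formula R = I + sin(t) K + (1 - cos(t)) K^2, where K is the
    skew-symmetric matrix of the *unit* axis. Building K from the unnormalized
    vector instead, the coefficients become sin(t)/t and (1 - cos(t))/t^2,
    both expressible via torch.sinc and hence stable at t = 0."""
    x, y, z = axis_angle.unbind(-1)
    zeros = torch.zeros_like(x)
    K = torch.stack(
        [zeros, -z, y, z, zeros, -x, -y, x, zeros], dim=-1
    ).reshape(axis_angle.shape[:-1] + (3, 3))
    t = torch.norm(axis_angle, dim=-1)[..., None, None]
    a = torch.sinc(t / torch.pi)                   # sin(t) / t
    b = 0.5 * torch.sinc(t / (2 * torch.pi)) ** 2  # (1 - cos(t)) / t^2
    eye = torch.eye(3, dtype=axis_angle.dtype, device=axis_angle.device)
    return eye + a * K + b * (K @ K)


# Rotating by pi/2 about z maps the x axis onto the y axis.
R = axis_angle_to_matrix_rodrigues(torch.tensor([0.0, 0.0, torch.pi / 2]))
print(R @ torch.tensor([1.0, 0.0, 0.0]))  # ~ [0, 1, 0] up to float error
```

The second coefficient follows from 1 - cos(t) = 2 sin²(t/2), so (1 - cos(t))/t² = 0.5 · (sin(t/2)/(t/2))² = 0.5 · sinc(t/(2π))², which is exactly 0.5 at t = 0 with no special-casing.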