Reduce overheads on several CPU kernels by avoiding restrides. #36875
Conversation
💊 Dr. CI build summary: as of commit 9405b4f, there are no build failures.
I believe the CI failures are unrelated. clang-format is complaining about [...]. I did a very long run (2 hours, 200 interleaved env switches) with the following result:

Master was built from the commit that this PR branched from, so meaningful differences on the reference tasks can't be chalked up to other diffs in the code. Looking at the distributions for [...], this is true even when I only run [...].
Ugh, bad rebase. Sorry all.
(force-pushed from b2086a4 to 4559b68)
Alright! Fast-forward resolved the CI failures.
Looks good!
Interesting what's happening with `max(x, y)` and why it changes, but it's all good.
@robieta has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Fixes a safety issue (nonsense values and segfaults) introduced by #36875 when in-place gather tries to use incorrect shapes. Consider the following block of code:

```python
k0 = 8
k1 = 8
m = 100

x = torch.rand((k0, k1))
ind = torch.randint(0, k0, (m, k1))
output = torch.empty((m, k1))

print(torch.gather(x, 0, ind, out=output))
print(torch.gather(x, 1, ind, out=output))
```

The first gather is legal, the second is not (`ind` and `output` need to be transposed). Previously this was caught when the kernel tried to restride inputs for TensorIterator, but we can no longer rely on those checks and must test explicitly. If `m` is small, the second gather returns gibberish; if it is large enough to push the read out of the memory block, the program segfaults.

Pull Request resolved: #37102

Differential Revision: D21190580

Pulled By: robieta

fbshipit-source-id: 80175620d24ad3380d78995f7ec7dbf2627d2998
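The shape rule behind that failure can be sketched in plain Python. The helper below is a hypothetical illustration, not PyTorch's actual check; the rule it encodes is the documented `torch.gather` requirement that `index.size(d) <= self.size(d)` for every `d != dim`, and that `out` must have the same shape as `index`.

```python
# Hypothetical sketch (not PyTorch's real code) of the explicit shape check
# the fix adds: for gather along `dim`, every index dimension other than
# `dim` must fit inside the corresponding input dimension, and the out
# tensor must have the same shape as the index tensor.

def check_gather_shapes(input_shape, index_shape, out_shape, dim):
    if len(input_shape) != len(index_shape):
        raise ValueError("input and index must have the same number of dims")
    if tuple(out_shape) != tuple(index_shape):
        raise ValueError("out must have the same shape as index")
    for d, (in_sz, ix_sz) in enumerate(zip(input_shape, index_shape)):
        if d != dim and ix_sz > in_sz:
            raise ValueError(
                f"index size {ix_sz} exceeds input size {in_sz} in dim {d}"
            )

# Mirrors the example above: x is (8, 8), ind and output are (100, 8).
check_gather_shapes((8, 8), (100, 8), (100, 8), dim=0)      # legal
try:
    check_gather_shapes((8, 8), (100, 8), (100, 8), dim=1)  # illegal
except ValueError as e:
    print("rejected:", e)
```

With the restride-based path gone, an explicit check like this is what stands between a bad call and an out-of-bounds read.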
Calling `t.as_strided(..., ...)` must make a new `TensorImpl` to back the new tensor, which takes 300-400 ns. Reduction, scatter/gather, and comparison kernels currently restride inputs and outputs in order to handle `dim` inside the function passed to TensorIterator. Because these Tensors are created solely for consumption by the iterator, a full restride and metadata copy is surplus to requirements. Moreover, shapes are already checked by these kernels prior to calling `add_input` and `add_output`, so shape inference and broadcasting are also unnecessary.

This PR adds a `TensorIterator::declare_static_shape(...)` method, which allows certain kernels to use a much more constrained and efficient shape path. This results in a 900-1200 ns speedup for `gather` / `scatter` / `scatter_add` / `cumsum` / `cumprod`, and a 250-500 ns speedup for elementwise `min` and `max`.
Measurements were taken with this Python script, which is driven by this bash script. The general procedure for mitigating environmental skew is to repeatedly switch between an environment built with master and one built with this branch while running the Python script. Within the Python measurement script, the following was used to reduce variation:
Two independent end-to-end runs are included since there is some variation even with the above measures. Overall measurement error seems to be about +/- 100 ns.
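The interleaved-environment procedure can be illustrated with a small, self-contained harness (hypothetical code, not the scripts referenced above): because each master/branch pair runs back-to-back, slow environmental drift hits both builds almost equally and mostly cancels out of the per-pair difference.

```python
# Hypothetical A/B harness (not the PR's actual scripts) showing why
# interleaving mitigates drift: each pair of back-to-back measurements
# sees nearly the same machine state, so the per-pair delta cancels slow
# drift that would bias two separate long runs.
import statistics

def interleaved_deltas(measure_a, measure_b, switches):
    """Alternate between the two builds, collecting per-pair differences."""
    deltas = []
    for _ in range(switches):
        deltas.append(measure_a() - measure_b())
    return deltas

# Simulated timings: the branch is 1000 ns faster, on top of a baseline
# that drifts 100 ns slower with every measurement.
drift = iter(range(0, 20000, 100))

def run_master():
    return 10_000 + next(drift)

def run_branch():
    return 9_000 + next(drift)

deltas = interleaved_deltas(run_master, run_branch, switches=100)
# 900 ns: the 1000 ns speedup minus the 100 ns drift between paired runs.
print(statistics.median(deltas))
```

Note that the residual bias equals only the drift accumulated *within* a pair, not across the whole session, which is why the reported +/- 100 ns error band is plausible despite multi-hour runs.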
The benchmark also includes several tasks which are not affected by this PR, both to check for a degradation in TensorIterator performance when static shapes are not set (which did happen for an earlier iteration of this optimization) and to estimate measurement variability and validate that measured improvements are significant.
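The fast path whose absence those control tasks guard against can be pictured with a toy model (hypothetical Python, not the real C++ `TensorIterator`): when a static shape is declared, operand setup skips broadcast-style shape inference entirely, which is where the saved nanoseconds come from.

```python
# Toy model (not the real C++ TensorIterator) contrasting the two setup
# paths. The broadcast path infers a common shape from every operand; the
# static-shape path skips inference because the kernel already validated
# shapes before adding operands.

def broadcast_shapes(*shapes):
    # Standard right-aligned broadcasting, as NumPy/PyTorch define it.
    ndim = max(len(s) for s in shapes)
    out = []
    for d in range(ndim):
        sizes = {s[len(s) - ndim + d] for s in shapes if len(s) - ndim + d >= 0}
        sizes.discard(1)
        if len(sizes) > 1:
            raise ValueError(f"incompatible sizes {sizes} in dim {d}")
        out.append(sizes.pop() if sizes else 1)
    return tuple(out)

class ToyIterator:
    def __init__(self):
        self.static_shape = None
        self.operands = []

    def declare_static_shape(self, shape):
        self.static_shape = tuple(shape)

    def add_operand(self, shape):
        self.operands.append(tuple(shape))

    def setup_shape(self):
        if self.static_shape is not None:
            return self.static_shape              # fast path: no inference
        return broadcast_shapes(*self.operands)   # general path

it = ToyIterator()
it.add_operand((8, 1))
it.add_operand((1, 8))
print(it.setup_shape())   # (8, 8) via broadcasting

it2 = ToyIterator()
it2.declare_static_shape((100, 8))
it2.add_operand((100, 8))
print(it2.setup_shape())  # (100, 8) without inference
```

The trade-off the control tasks probe is visible here: the general path must stay as fast as before for kernels that never call `declare_static_shape`.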
First run:
Second run:
CC @ilia-cher @VitalyFedyunin @glaringlee @gdankel