Add torch.logcumsumexp
#36308
Conversation
💊 CI failures summary and remediations
As of commit c0ee7ab (more details on the Dr. CI page):
❄️ 1 failure tentatively classified as flaky, but reruns have not yet been triggered to confirm: pytorch_libtorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test (1/1). Step: "Set Up CI Environment After attach_workspace".
The log-sum-exp trick doesn't seem to be working; the gradient check doesn't pass with it.
I don't think the current implementation is numerically stable. The original issue #26411 suggests using cummax to reduce numerical error, but that has quadratic complexity (see the discussion in #32876).
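For context, the usual stable pairwise log-add-exp pulls the maximum out of the exponent before exponentiating. A minimal sketch (illustrative only, not the code in this PR; `log_add_exp` is a made-up name):

```python
import torch

def log_add_exp(a, b):
    # Stable log(exp(a) + exp(b)): factor out the larger value so the
    # remaining exp() argument is always <= 0 and cannot overflow.
    m = torch.max(a, b)
    return m + torch.log1p(torch.exp(-(a - b).abs()))

x = torch.tensor(1000.0)
print(torch.log(torch.exp(x) + torch.exp(x)))  # inf -- exp(1000) overflows
print(log_add_exp(x, x))                       # tensor(1000.6931)
```

Folding elements in one at a time with this op, rather than exponentiating the whole prefix at once, is what keeps a cumulative logsumexp stable.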
@anjali411 @albanD Could you please review the PR? Also, I don't have the rights to re-run the failed pipeline (its failure is unrelated to this PR). @tridao I get the point, but I am not sure that I am quite familiar with the codebase. I saw the reference for …
Agreed, those seem to be where cumsum is implemented. Re: the CUDA implementation in THC: maybe cumsum will eventually be ported from THC to ATen? I also found the cummax implementation in ATen (https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/ReduceOpsKernel.cu). One would replace max(x, y) with log_add_exp. The cummax implementation also needs the indices, but I don't think logcumsumexp will need those.
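To make the combine-op point concrete, here is an illustrative sketch (plain Python with made-up names, not the ATen/THC kernel): cummax and logcumsumexp fit the same inclusive-scan pattern and differ only in the associative binary operation that folds in each new element.

```python
import torch

def log_add_exp(a, b):
    # Same stable pairwise op as in the earlier sketch.
    m = torch.max(a, b)
    return m + torch.log1p(torch.exp(-(a - b).abs()))

def inclusive_scan(x, dim, combine):
    # Generic O(n) sequential inclusive scan along `dim` with an
    # associative binary op; stands in for the templated scan kernel.
    slices = x.unbind(dim)
    acc = [slices[0]]
    for s in slices[1:]:
        acc.append(combine(acc[-1], s))
    return torch.stack(acc, dim)

x = torch.randn(5)
cummax_values = inclusive_scan(x, 0, torch.max)  # cummax values (no indices)
lcse = inclusive_scan(x, 0, log_add_exp)         # logcumsumexp

# Sanity checks against existing ops (naive form is fine for moderate inputs).
print(torch.allclose(lcse, torch.log(torch.cumsum(torch.exp(x), dim=0))))  # True
print(torch.equal(cummax_values, torch.cummax(x, dim=0).values))           # True
```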
* Add TODO about code duplication.
* Fix a comment.
Looks good to me.
The perf for the backward might not be great, but it is better than nothing.
Feel free to open a new issue if you think the backward should be implemented with a special kernel to make it more efficient. But I think this should be left for a future PR anyway.
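For reference, the gradient that falls out of the definition only needs a reverse cumulative sum, so a backward can already be expressed with existing ops. A hedged sketch (with `y = logcumsumexp(x, dim)`; this is the plain-space form, not necessarily the derivative formula this PR lands):

```python
import torch

def logcumsumexp_backward(grad, x, y, dim):
    # From y_i = log(sum_{k <= i} exp(x_k)):
    #   dL/dx_j = sum_{i >= j} grad_i * exp(x_j - y_i)
    #           = exp(x_j) * reverse_cumsum(grad * exp(-y))_j
    # Plain-space version; a production kernel would stay in log space
    # (e.g. splitting grad by sign) for better numerical behaviour.
    w = grad * torch.exp(-y)
    rev_cumsum = w.flip([dim]).cumsum(dim).flip([dim])
    return torch.exp(x) * rev_cumsum

# Compare against autograd through the naive composition of existing ops.
x = torch.randn(6, dtype=torch.double, requires_grad=True)
y = torch.log(torch.cumsum(torch.exp(x), dim=0))
grad = torch.randn(6, dtype=torch.double)
y.backward(grad)
print(torch.allclose(x.grad, logcumsumexp_backward(grad, x.detach(), y.detach(), 0)))  # True
```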
@albanD has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
The landing is in progress, but there are some flaky internal tests, so I had to re-run those...
@kshitij12345 thank you very much, great work!
@kshitij12345 it seems like the documentation is a bit broken. At least the computation formula does not currently appear in the master docs (https://pytorch.org/docs/master/generated/torch.logcumsumexp.html#torch.logcumsumexp).
@agadetsky Thanks for letting me know. Will try to get it fixed.
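For reference, the formula the docs page is expected to render is, in the 1-D case (notation here is mine):

```latex
y_i = \log \sum_{j=0}^{i} \exp(x_j)
```

i.e. each output element is the log-sum-exp of the input prefix up to and including that position.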
Summary: References: #24521, #24522, #24547, #24548, #24507. Depends on #36308. Changes related to this PR are only in these files:
aten/src/ATen/Declarations.cwrap
aten/src/ATen/native/cuda/ReduceOpsKernel.cu
aten/src/ATen/native/native_functions.yaml
aten/src/THC/generic/THCTensorMathScan.cu
aten/src/THC/generic/THCTensorMathScan.h
Please review, VitalyFedyunin. Thanks.
Pull Request resolved: #36458
Differential Revision: D21718384
Pulled By: ngimel
fbshipit-source-id: 5af15164050c77be164397abd659a48c9ded2b29
Summary: Reference: #36308 (comment). After fix: ![Screenshot from 2020-05-23 15-35-09](https://user-images.githubusercontent.com/19503980/82727956-4bcabb80-9d0b-11ea-85a8-81b35012abbc.png)
Pull Request resolved: #38952
Differential Revision: D21722196
Pulled By: ezyang
fbshipit-source-id: 62b08c14e0ce9603133841940627df40d7b1e861
Creating a new PR as I am unable to push to @pandeykartikey's branch (I don't have the permissions).
Closes #26411
Based on #32876. Thanks @pandeykartikey for starting this out.
Have addressed the comments.
@anjali411 @agadetsky @albanD