Fix scalar indexing of ProjectTo for wrappers of GPU arrays #630
base: main
I seem to have accidentally reformatted a few lines in the projection.jl test file. I can revert those changes if needed.
The ChainRules test failure is actually a test marked as broken that now passes, i.e. this PR fixes it. I have no idea what is causing Diffractor to fail. Judging from other PRs, the failures in ChainRulesOverloadGeneration, StatsFuns and LogExpFunctions seem to be expected.
I would rather have this handled as an extension package that lives in GPUArraysCore. Can we instead work out which interfaces we would need to expose for GPUArraysCore to hook into?
Indeed, it could all be done with overloads; I was just trying to avoid the copy-pasting, as you say. Aside from ProjectTo, which is already exported, I think the only things needed from CRC are
I actually forgot about extensions, and I'm not too familiar with how they work, but this does seem like an ideal use case for them. What is the benefit of extending GPUArraysCore rather than ChainRulesCore?
The maintainers of
This PR attempts to eliminate scalar indexing in ProjectTo for GPUArrays by always restricting the depth of array wrappers to 1.
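To make "wrapper depth" concrete: a depth-1 wrapper is a single wrapper type such as `Transpose` or `SubArray` around a plain storage array, while depth 2 nests one wrapper inside another. The minimal sketch below counts depth by following `parent` until it returns its own argument (Base's fallback for unwrapped arrays). It uses only Base and LinearAlgebra, runs without a GPU, and `wrapper_depth` is a hypothetical helper for illustration, not part of ChainRulesCore, GPUArrays or this PR.

```julia
using LinearAlgebra  # transpose / Transpose

"""
    wrapper_depth(x)

Hypothetical helper: count how many array wrappers (e.g. `Transpose`,
`SubArray`) sit on top of the underlying storage array. `parent` returns
its argument unchanged once no wrapper is left, which ends the loop.
"""
function wrapper_depth(x::AbstractArray)
    depth = 0
    while parent(x) !== x
        x = parent(x)   # peel off one wrapper layer
        depth += 1
    end
    return depth
end

A = rand(3, 3)
wrapper_depth(A)                          # 0: plain Array
wrapper_depth(transpose(A))               # 1: Transpose of an Array
wrapper_depth(transpose(view(A, 1:2, :))) # 2: Transpose of a SubArray
```

Following `parent` generically keeps the sketch short; an actual projection rule would more likely dispatch on concrete wrapper types so it can rebuild a depth-1 result.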
Fixes #624. As a first step, I also copied many of the standard Array tests, which turned up several other cases where scalar indexing occurred or the projected types were incorrect (i.e., nested wrappers rather than plain arrays).
In summary, the changes consist of:
While this isn't exactly ideal or elegant, it does let AD through depth-1 wrappers of GPUArrays avoid scalar indexing, bringing the backward pass to parity with the forward pass. Hopefully the additional dependencies are acceptable and I've added them to Project.toml correctly.