
Fix scalar indexing of ProjectTo for wrappers of GPU arrays #630

Open: wants to merge 8 commits into main

DomCRose commented Sep 4, 2023

This PR attempts to eliminate scalar indexing in ProjectTo for GPUArrays by restricting the wrapper depth of projected results to 1 at all times.

Fixes #624. As a first step, I also copied many of the standard Array tests and found several other cases where scalar indexing occurred or the projected type was incorrect (i.e. nested wrappers rather than plain arrays).

In summary, the changes consist of:

  1. Adding JLArrays as a test dependency and copying and adapting many of the Array tests to find cases that give incorrect results or trigger scalar indexing (some cases may well have been missed, but I've tried to be thorough).
  2. Adding GPUArraysCore as a source dependency so that overloads could be added to ProjectTo. This allows limiting the wrapper depth for GPUArrays without hampering CPU performance (I hope).
  3. Adding overloads for projections related to adjoints and transposes of GPUArrays.

While this isn't exactly ideal or elegant, it does let AD through depth-1 wrappers of GPUArrays run without triggering scalar indexing, bringing the backward pass to parity with the forward pass. Hopefully the additional dependencies are acceptable and I've added them to Project.toml correctly.

DomCRose (Author) commented Sep 4, 2023

I seem to have accidentally reformatted a few lines in the projection.jl test file. I can revert those changes if needed.

DomCRose (Author) commented Sep 4, 2023

The ChainRules test failure is actually a test marked as broken that now passes. I have no idea what is causing Diffractor to fail. Judging from other PRs, the failures in ChainRulesOverloadGeneration, StatsFuns and LogExpFunctions seem to be expected.

oxinabox (Member) commented Sep 18, 2023

I would rather have this handled as a package extension that lives in GPUArraysCore.
Adding this directly to CRC adds GPUArraysCore as a dependency of well over 1000 downstream packages.

Can we instead work out what interfaces we would need to expose for GPUArraysCore to hook into?
It looks like the only change that wasn't done simply by adding an overload was the change to (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M},
and that could also be done with an overload; it would just involve a little more copy-paste than is ideal.

DomCRose (Author) commented Sep 18, 2023

Indeed, it could all be done with overloads; I was just trying to avoid the copy-pasting, as you say.

Aside from ProjectTo, which is already exported, I think the only things needed from CRC are project_type and _projection_mismatch (a convenience function for raising the error). project_type could be exported, or, to avoid exporting it, the getproperty overload for ProjectTo could be modified, e.g. so that if the symbol is :project_type it returns the first type parameter. I'm not sure what the most appropriate way to handle the error function would be.

I had actually forgotten about package extensions, and I'm not too familiar with how they work, but this does seem like an ideal use case for them. What is the benefit of extending GPUArraysCore rather than ChainRulesCore?

oxinabox (Member) commented

> What is the benefit of extending GPUArraysCore rather than ChainRulesCore?

The maintainers of GPUArraysCore are much more familiar than the maintainers of ChainRulesCore with what a GPU array represents and which operations are allowable on it.
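For readers unfamiliar with package extensions (available since Julia 1.9), the layout proposed above would have GPUArraysCore declare ChainRulesCore as a weak dependency, plus an extension module that Julia loads only when both packages are loaded. The following is a minimal sketch under that assumption and is not part of this PR. The extension name GPUArraysCoreChainRulesCoreExt is hypothetical; its module, in ext/GPUArraysCoreChainRulesCoreExt.jl, would hold the ProjectTo overloads.

```toml
# Hypothetical additions to GPUArraysCore's Project.toml (Julia >= 1.9 package extensions).

[weakdeps]
# Not installed or loaded unless the user's environment also uses ChainRulesCore.
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

[extensions]
# Hypothetical extension name: Julia loads ext/GPUArraysCoreChainRulesCoreExt.jl
# once both GPUArraysCore and ChainRulesCore are loaded.
GPUArraysCoreChainRulesCoreExt = "ChainRulesCore"
```

With this layout ChainRulesCore's own dependency list stays unchanged, and GPUArraysCore users who never use ChainRulesCore do not pull it in. A [compat] bound for ChainRulesCore would normally be added alongside the weak dependency.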

Successfully merging this pull request may close this issue: ProjectTo causes scalar indexing when taking adjoints of complex CuArray
2 participants