make CUDA randn work with Zygote #2581

Draft
wants to merge 1 commit into base: master

@bgctw bgctw commented Dec 9, 2024

Currently, I get errors when using CUDA in combination with Zygote and random numbers.
mcabbott advised adding a @non_differentiable rule for CUDA.randn to CUDA's ChainRulesCoreExt, so that all users can benefit.
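
For readers unfamiliar with the macro, here is a minimal sketch of what such a rule does. It runs on the CPU and uses a hypothetical stand-in function, `my_randn`, in place of `CUDA.randn`; only `ChainRulesCore` is assumed to be installed. The macro form matches the line proposed in this PR, but the stand-in is for illustration only and is not part of the change.

```julia
using ChainRulesCore

# Hypothetical stand-in for CUDA.randn so the sketch runs without a GPU.
my_randn(dims::Integer...) = randn(Float32, dims...)

# Same macro form as the rule proposed in this PR, applied to the stand-in.
ChainRulesCore.@non_differentiable my_randn(::Any...)

# The macro defines frule and rrule methods for my_randn.
y, pullback = ChainRulesCore.rrule(my_randn, 4, 4)

# The pullback reports "no tangent" for the function and for each argument,
# which tells an AD system such as Zygote not to differentiate through the call.
tangents = pullback(ones(Float32, 4, 4))
println(tangents)
```

With a rule like this in place, an AD backend that consumes ChainRules should treat the random draw as a constant rather than trying to trace into the sampler.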

Member

@maleadt maleadt left a comment

Thanks. Needs a rebase for the CI failure.

@@ -2,12 +2,15 @@

module ChainRulesCoreExt

-using CUDA: CuArray
+using CUDA: CuArray, CUDA
Member

Suggested change
-using CUDA: CuArray, CUDA
+using CUDA

@@ -0,0 +1,19 @@
using GPUArraysCore: GPUArraysCore
Member

Suggested change
-using GPUArraysCore: GPUArraysCore
+using GPUArrays

GPUArrays re-exports the Core functionality.

function call_rand(v::AbstractVector{T}) where {T}
    randn(T, 4,4) * v[1:4]
end
function call_rand(v::GPUArraysCore.AbstractGPUVector{T}) where {T}
Member

Suggested change
-function call_rand(v::GPUArraysCore.AbstractGPUVector{T}) where {T}
+function call_rand(v::AbstractGPUVector{T}) where {T}


isdefined(Base, :get_extension) ? (import ChainRulesCore) : (import ..ChainRulesCore)

## support ChainRulesCore inplaceability

ChainRulesCore.is_inplaceable_destination(::CuArray) = true

# allow usage of rand with Zygote
ChainRulesCore.@non_differentiable CUDA.randn(::Any...)
Contributor

rand too?

Suggested change
-ChainRulesCore.@non_differentiable CUDA.randn(::Any...)
+ChainRulesCore.@non_differentiable CUDA.rand(::Any...)
+ChainRulesCore.@non_differentiable CUDA.randn(::Any...)

Tab completion says there are a few more, but not marked public, so IDK:

julia> CUDA.rand
rand              rand_logn
rand_logn!        rand_poisson
rand_poisson!     randexp_unlikely
randn             randn_unlikely
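
One way to check which of these names are exported or declared public, rather than relying on tab completion, might look like the sketch below. It is only an illustration: `rand_like_report` is a hypothetical helper, the `Random` stdlib stands in for `CUDA` so the code runs without a GPU, and `Base.ispublic` is assumed to be available only on Julia 1.11 or newer (the sketch reports `missing` otherwise).

```julia
using Random  # stdlib stand-in; pass CUDA instead when it is installed

# Hypothetical helper: list every binding in module `m` whose name starts
# with "rand", noting whether it is exported and whether it is public.
function rand_like_report(m::Module)
    syms = filter(s -> startswith(string(s), "rand"), names(m; all=true))
    has_ispublic = isdefined(Base, :ispublic)  # Base.ispublic was added in Julia 1.11
    return [(name = s,
             exported = Base.isexported(m, s),
             public = has_ispublic ? Base.ispublic(m, s) : missing)
            for s in syms]
end

report = rand_like_report(Random)
foreach(println, report)
```

Names that come back as neither exported nor public are probably internal, so adding rules for them would be a judgment call for the maintainers.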

@maleadt maleadt force-pushed the master branch 15 times, most recently from 5d585c4 to c850163 December 20, 2024 08:18
@maleadt maleadt marked this pull request as draft January 8, 2025 10:06
3 participants