
Enzyme fails to reverse mode AD through broadcast assignment with CuArrays #2116

Closed
simenhu opened this issue Nov 22, 2024 · 3 comments · Fixed by JuliaGPU/CUDA.jl#2563


@simenhu

simenhu commented Nov 22, 2024

The problem originated when trying to use Enzyme.jl through `EnzymeVJP()` in SciMLSensitivity. After boiling it down to a small example, it seems that Enzyme has problems compiling reverse-mode AD through the `.=` operator in the system equation. The code works if you replace the `CuArray` type with `Array` on line 7 (`gpu = CuArray`).
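
To separate the right-hand side from the solver, a minimal CPU-only sketch (a reduction for illustration, not part of the original report) can differentiate the same `du .= sys*u` assignment directly with Enzyme on plain `Array`s, the case the report says works. The output seed `du_bar = [1.0, 0.0]` and the `_bar` variable names are illustrative choices; no OrdinaryDiffEq or SciMLSensitivity code is involved.

```julia
using Enzyme

# Same right-hand side as in the report: in-place broadcast assignment.
function mass_damper!(du, u, p, t)
    sys = p
    du .= sys * u
    nothing
end

sys = [0 1.0; -1.0 -0.5]   # mass = 1.0, spring = 1.0, dampening = 0.5
u   = [1.0, 0.0]
du  = zero(u)

# Shadow (adjoint) buffers; seeding du_bar with the first unit vector
# pulls back the first component of du.
du_bar  = [1.0, 0.0]
u_bar   = zero(u)
sys_bar = zero(sys)

Enzyme.autodiff(Enzyme.Reverse, mass_damper!, Enzyme.Const,
    Enzyme.Duplicated(du, du_bar),
    Enzyme.Duplicated(u, u_bar),
    Enzyme.Duplicated(sys, sys_bar),
    Enzyme.Const(0.0))

# Analytically: u_bar = sys' * [1, 0] = [0.0, 1.0]
#               sys_bar = [1, 0] * u'  = [1.0 0.0; 0.0 0.0]
@show u_bar sys_bar
```

Replacing these arrays with `CuArray`s is the change that, in the full reproduction, leads to the `EnzymeNoDerivativeError` inside `launch_configuration`.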

Code to reproduce error:

using OrdinaryDiffEq
using Plots
using Enzyme
using CUDA
using GPUArraysCore

gpu = CuArray

# Define the mass-damper system
function mass_damper!(du, u, p, t)
    sys = p
    du .= sys*u
    nothing
end

# Initial conditions: position and velocity
u0 = gpu([1.0; 0.0])

# Parameters: mass, spring and damping coefficient
p = (mass=1.0, spring=1.0, dampening=0.5)
sys = gpu([0 1.0; -p.spring/p.mass -p.dampening/p.mass])

# Time span for the simulation
tspan = (0.0, 50.0)

# Define the problem
prob = ODEProblem{true}(mass_damper!, u0, tspan, sys)

# Solve the problem
sol = solve(prob, Tsit5())

# Plot the solution
GPUArraysCore.@allowscalar display(plot(sol))

################## To check the EnzymeVJP() separately ###################
function check_enzyme_VJP(prob)
    u0 = prob.u0
    p = prob.p
    tmp2 = Enzyme.make_zero(p)
    t = prob.tspan[1]
    du = zero(u0)

    if DiffEqBase.isinplace(prob)
        _f = prob.f
    else
        _f = (du, u, p, t) -> (du .= prob.f(u, p, t); nothing)
    end

    _tmp6 = Enzyme.make_zero(_f)
    tmp3 = zero(u0)
    tmp4 = zero(u0)
    ytmp = zero(u0)
    tmp1 = zero(u0)

    Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(_f, _tmp6),
        Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4),
        Enzyme.Duplicated(ytmp, tmp1),
        Enzyme.Duplicated(p, tmp2),
        Enzyme.Const(t))
end

@info "Checking Enzyme VJP..."
prob = ODEProblem{true}(mass_damper!, u0, tspan, sys)
check_enzyme_VJP(prob)

Error message:

ERROR: Enzyme.Compiler.EnzymeNoDerivativeError(Cstring(0x00007f5ee76d6553))
Stacktrace:
  [1] launch_configuration
    @ ~/.julia/packages/CUDA/2kjXI/lib/cudadrv/occupancy.jl:56 [inlined]
  [2] #launch_heuristic#1200
    @ ~/.julia/packages/CUDA/2kjXI/src/gpuarrays.jl:22 [inlined]
  [3] launch_heuristic
    @ ~/.julia/packages/CUDA/2kjXI/src/gpuarrays.jl:15 [inlined]
  [4] _copyto!
    @ ~/.julia/packages/GPUArrays/qt4ax/src/host/broadcast.jl:78 [inlined]
  [5] materialize!
    @ ~/.julia/packages/GPUArrays/qt4ax/src/host/broadcast.jl:38 [inlined]
  [6] materialize!
    @ ./broadcast.jl:911 [inlined]
  [7] mass_damper!
    @ ~/programming/birdsview/julia_FDTD_debuging/mass_damper_Enzyme_julia_MWE.jl:12
  [8] ODEFunction
    @ ~/.julia/packages/SciMLBase/NtgCQ/src/scimlfunctions.jl:2358 [inlined]
  [9] ODEFunction
    @ ~/.julia/packages/SciMLBase/NtgCQ/src/scimlfunctions.jl:0 [inlined]
 [10] diffejulia_ODEFunction_9303_inner_1wrap
    @ ~/.julia/packages/SciMLBase/NtgCQ/src/scimlfunctions.jl:0
 [11] macro expansion
    @ ~/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8398 [inlined]
 [12] enzyme_call
    @ ~/.julia/packages/Enzyme/RTS5U/src/compiler.jl:7950 [inlined]
 [13] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/RTS5U/src/compiler.jl:7723 [inlined]
 [14] autodiff
    @ ~/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:491 [inlined]
 [15] check_enzyme_VJP(prob::ODEProblem{…})
    @ Main ~/programming/birdsview/julia_FDTD_debuging/mass_damper_Enzyme_julia_MWE.jl:55
 [16] top-level scope
    @ ~/programming/birdsview/julia_FDTD_debuging/mass_damper_Enzyme_julia_MWE.jl:64
Some type information was truncated. Use `show(err)` to see complete types.

`versioninfo()` output of the system:


Julia Version 1.10.6
Commit 67dffc4a8ae (2024-10-28 12:23 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 256 × AMD EPYC 7H12 64-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 30 default, 0 interactive, 15 GC (on 256 virtual cores)
Environment:
  LD_LIBRARY_PATH = /usr/local/cuda-11.8/lib64
  JULIA_SSL_NO_VERIFY_HOSTS = github.com, julialang-s3.julialang.org, pkg.julialang.org
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 30

Packages are:

  [052768ef] CUDA v5.5.2
  [7da242da] Enzyme v0.13.16
⌅ [46192b85] GPUArraysCore v0.1.6
  [1dea7af3] OrdinaryDiffEq v6.90.1
  [91a5bcdd] Plots v1.40.9
@wsmoses
Member

wsmoses commented Nov 24, 2024

I think JuliaGPU/CUDA.jl#2563 should resolve this; could you give it a go?

@wsmoses wsmoses closed this as completed Nov 24, 2024
@simenhu
Author

simenhu commented Nov 25, 2024

Hi, thanks for the rapid reply @wsmoses! I tried the fix by upgrading to CUDA 5.5.2 and Enzyme 0.13.16, and it seems that `launch_configuration` is still captured by Enzyme. To make sure the rule is used by Enzyme, I copied your fix directly into the script above:

using OrdinaryDiffEq
using Plots
using Enzyme
using CUDA
using EnzymeCore

function EnzymeCore.EnzymeRules.inactive_noinl(::typeof(CUDA.launch_configuration), args...; kwargs...)
    return nothing
end

gpu = CuArray

# Define the mass-damper system
function mass_damper!(du, u, p, t)
    sys = p
    du .= sys*u
    nothing
end

etc...

But it still gives me the following error, which makes me suspect that the function is not fully ignored:

ERROR: Enzyme.Compiler.EnzymeNoDerivativeError(Cstring(0x00007f08a90ca4e5))
Stacktrace:
  [1] launch_configuration
    @ ~/.julia/packages/CUDA/2kjXI/lib/cudadrv/occupancy.jl:56 [inlined]
  [2] #launch_heuristic#1200
    @ ~/.julia/packages/CUDA/2kjXI/src/gpuarrays.jl:22 [inlined]
  [3] launch_heuristic
    @ ~/.julia/packages/CUDA/2kjXI/src/gpuarrays.jl:15 [inlined]
  [4] _copyto!
    @ ~/.julia/packages/GPUArrays/qt4ax/src/host/broadcast.jl:78 [inlined]
  [5] materialize!
    @ ~/.julia/packages/GPUArrays/qt4ax/src/host/broadcast.jl:38 [inlined]
  [6] materialize!
    @ ./broadcast.jl:911 [inlined]
  [7] mass_damper!
    @ ~/programming/birdsview/julia_FDTD_debuging/mass_damper_Enzyme_julia_MWE.jl:16
  [8] ODEFunction
    @ ~/.julia/packages/SciMLBase/ZyZAV/src/scimlfunctions.jl:2358 [inlined]
  [9] ODEFunction
    @ ~/.julia/packages/SciMLBase/ZyZAV/src/scimlfunctions.jl:0 [inlined]
 [10] diffejulia_ODEFunction_8149_inner_1wrap
    @ ~/.julia/packages/SciMLBase/ZyZAV/src/scimlfunctions.jl:0
 [11] macro expansion
    @ ~/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8398 [inlined]
 [12] enzyme_call
    @ ~/.julia/packages/Enzyme/RTS5U/src/compiler.jl:7950 [inlined]
 [13] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/RTS5U/src/compiler.jl:7723 [inlined]
 [14] autodiff
    @ ~/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:491 [inlined]
 [15] check_enzyme_VJP(prob::ODEProblem{…})
    @ Main ~/programming/birdsview/julia_FDTD_debuging/mass_damper_Enzyme_julia_MWE.jl:60
 [16] top-level scope
    @ ~/programming/birdsview/julia_FDTD_debuging/mass_damper_Enzyme_julia_MWE.jl:69
Some type information was truncated. Use `show(err)` to see complete types.

What do you think? Which version of Julia are you using? I'm currently on 1.10.6, after seeing some issues reported with the combination of Julia 1.11.x and CUDA/Enzyme.

@simenhu
Author

simenhu commented Nov 26, 2024

I replaced the rule you defined with the variant that prevents inlining:

function EnzymeCore.EnzymeRules.inactive(::typeof(CUDA.launch_configuration), args...; kwargs...)
    return nothing
end

This seems to work! I don't know why it fails with the version that allows inlining, though. Do you have any idea?
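
The rule form discussed in this thread can be illustrated on the CPU. The following is a minimal sketch, assuming a hypothetical helper `pick_split` that stands in for a host-side heuristic like `CUDA.launch_configuration`: it returns an integer that changes control flow but carries no derivative information. It marks the helper with `EnzymeRules.inactive`; according to the Enzyme.jl documentation, `EnzymeRules.inactive_noinl` marks a function inactive in the same way but does not prevent it from being inlined. On the CPU this function would differentiate fine without the rule, so the sketch only shows where such a rule goes.

```julia
using Enzyme
import Enzyme: EnzymeRules

# Hypothetical stand-in for a host-side heuristic such as CUDA.launch_configuration:
# it only decides how the work is split and carries no derivative information.
pick_split(n::Int) = max(1, n ÷ 2)

# Mark every call to pick_split as inactive. Per the Enzyme docs, `inactive`
# also prevents inlining, while `EnzymeRules.inactive_noinl` leaves inlining enabled.
EnzymeRules.inactive(::typeof(pick_split), args...) = nothing

function weighted_sumsq(x)
    n = pick_split(length(x))      # inactive integer result
    s = 0.0
    for j in eachindex(x)
        w = j <= n ? 1.0 : 2.0     # weight depends only on the inactive split point
        s += w * x[j]^2
    end
    return s
end

x  = [1.0, 2.0, 3.0, 4.0]
dx = zero(x)
Enzyme.autodiff(Enzyme.Reverse, weighted_sumsq, Enzyme.Active, Enzyme.Duplicated(x, dx))

# The gradient of sum(w_j * x_j^2) is 2 * w_j * x_j, so dx == [2.0, 4.0, 12.0, 16.0]
@show dx
```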

Labels
None yet
Projects
None yet

2 participants