
Adjoint of cholesky is hard-coded for the CPU #1210

Closed
Red-Portal opened this issue Apr 23, 2022 · 15 comments

@Red-Portal

Hi,

I've been attempting to differentiate through a Cholesky decomposition, which is common practice in Gaussian processes. The problem is that the current adjoint for cholesky is hard-coded to the CPU version of trsm!.

See the following minimal working example:

using CUDA
using KernelAbstractions
using CUDAKernels
using LinearAlgebra

import Tullio
import Zygote

function main()
    N = 1024
    D = 16
    X = randn(Float32, D, N)
    y = randn(Float32, N)

    CUDA.allowscalar(true)
    X_dev = CuArray(X)
    y_dev = CuArray(y)
    @time begin
        ∇K = Zygote.gradient(cu(randn(Float32, D+2))) do θ
            ℓα     = θ[1:1]
            ℓϵ     = θ[2]
            logℓ   = θ[3:end]
            Tullio.@tullio K[i,j] := exp(ℓα[1]*2 - (X_dev[k,i] - X_dev[k,j])^2 / exp(2*logℓ[k])) verbose=true
            K_ϵ      = K + cu(exp(ℓϵ)*I)
            K_ϵ_chol = cholesky(K_ϵ)
            α        = K_ϵ_chol \ y_dev
            dot(α, y_dev)
        end
    end
end

main()

output:

ERROR: ArgumentError: cannot take the CPU address of a CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}
Stacktrace:
  [1] unsafe_convert(#unused#::Type{Ptr{Float32}}, x::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ CUDA ~/.julia/packages/CUDA/5jdFl/src/array.jl:315
  [2] trsm!(side::Char, uplo::Char, transa::Char, diag::Char, alpha::Float32, A::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, B::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ LinearAlgebra.BLAS /usr/share/julia/stdlib/v1.7/LinearAlgebra/src/blas.jl:1958
  [3] (::Zygote.var"#817#818"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Cholesky{Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}})(Δ::NamedTuple{(:uplo, :info, :factors), Tuple{Nothing, Nothing, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}})
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/lib/array.jl:603
  [4] (::Zygote.var"#3217#back#819"{Zygote.var"#817#818"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Cholesky{Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}})(Δ::NamedTuple{(:uplo, :info, :factors), Tuple{Nothing, Nothing, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [5] Pullback
    @ ./REPL[33]:17 [inlined]
  [6] (::typeof((λ)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
  [7] (::Zygote.var"#56#57"{typeof((λ))})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:41
  [8] gradient(f::Function, args::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:76

A simple fix is to use the following snippet:

@eval Zygote begin
    import CUDA
    @adjoint function cholesky(Σ::CUDA.CuArray; check = true)
        C = cholesky(Σ, check = check)
        C, function(Δ::NamedTuple)
            issuccess(C) || throw(PosDefException(C.info))
            U, Ū = C.U, Δ.factors

            U_tru = triu(U.data)
            Ū_tru = triu(Ū.data)

            Σ̄ = similar(U.data)
            Σ̄ = mul!(Σ̄, Ū_tru, U_tru')
            Σ̄ = copytri!(Σ̄, 'U')
            Σ̄ = ldiv!(U, Σ̄)
            Σ̄ = CUDA.CUBLAS.trsm!('R', 'U', 'T', 'N', one(eltype(Σ)), U.data, Σ̄)
            Σ̄[diagind(Σ̄)] ./= 2
            return (UpperTriangular(Σ̄),)
        end
    end
end

The two calls to triu are necessary to work around a performance bug in the matrix multiplication between two triangular matrices. I didn't pursue the cause further, but it seems that multiplying two triangular matrices on the GPU is roughly 100 times slower than a plain dense matrix multiplication. Any thoughts on the reason for this?
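
For reference, a minimal sketch for probing that gap (the size N is arbitrary, and BenchmarkTools is assumed to be available; depending on the CUDA.jl/GPUArrays version, the wrapped product may route through a generic fallback kernel or even demand scalar indexing, which is exactly the behaviour being questioned here):

using CUDA, LinearAlgebra, BenchmarkTools

N = 1024
A = CUDA.randn(Float32, N, N)
B = CUDA.randn(Float32, N, N)

# Fast path used in the workaround: materialize the triangles, then a plain GEMM.
@btime CUDA.@sync(triu($A) * triu($B)')

# Path the stock adjoint takes: multiply through the triangular wrappers.
@btime CUDA.@sync(UpperTriangular($A) * UpperTriangular($B)')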

@CarloLucibello
Member

The rrule in ChainRules.jl is also hard-coded for the CPU, so removing the @adjoint from Zygote won't fix the problem.
Why does CUDA.jl define its own trsm! instead of overloading Base's?

@CarloLucibello
Member

> The two calls to triu are necessary to work around a performance bug in the matrix multiplication between two triangular matrices. I didn't pursue the cause further, but it seems that multiplying two triangular matrices on the GPU is roughly 100 times slower than a plain dense matrix multiplication. Any thoughts on the reason for this?

This seems worth opening an issue in CUDA.jl

@Red-Portal
Author

Hi @CarloLucibello,

> The rrule in ChainRules.jl is also hard-coded for the CPU, so removing the @adjoint from Zygote won't fix the problem.
> Why does CUDA.jl define its own trsm! instead of overloading Base's?

Does Base have trsm!, though? The offending trsm! is part of BLAS, not LinearAlgebra, and it seems CUDA.jl overloads most of the routines in LinearAlgebra, just not those in BLAS.

@ToucheSir
Member

It appears certain functions in LinearAlgebra do call trsm! though, so if you can update the rules to work with those there may be a better chance of getting GPU compatibility.
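
For concreteness, the BLAS.trsm! call in the failing adjoint is just a right division by U' at the LinearAlgebra level. A minimal CPU check of that correspondence (sizes and variable names are illustrative only; whether the dispatch-based form has the needed methods for CuArray is exactly what the later comments dig into):

using LinearAlgebra
using LinearAlgebra: BLAS

n = 8
A = randn(Float32, n, n)
U = cholesky(Symmetric(A' * A + I)).U   # UpperTriangular Cholesky factor
Σ̄ = randn(Float32, n, n)

# In-place BLAS form used by the adjoint: Σ̄ ← Σ̄ * U⁻ᵀ
via_blas = BLAS.trsm!('R', 'U', 'T', 'N', one(Float32), U.data, copy(Σ̄))
# Equivalent generic form: right division by the adjoint of U
via_generic = copy(Σ̄) / U'

@assert via_blas ≈ via_generic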

@Red-Portal
Author

Red-Portal commented Apr 27, 2022

@ToucheSir You're suggesting calling the LinearAlgebra wrappers around trsm!, right? That seems like a fair suggestion. I'll try to sort out the UpperTriangular/LowerTriangular issues with the CUDA.jl folks and see if I can do something about the rrules.

@Red-Portal
Author

Hi @ToucheSir, I've checked whether I could quickly fix this issue, and it seems that the choice of invoking BLAS.trsm! was out of unfortunate necessity. The trsm! routine is currently used to solve an in-place right-sided triangular-matrix/dense-matrix system. The corresponding LinearAlgebra routine is rdiv!, but it cannot be applied here because it doesn't support the triangular-matrix/dense-matrix case, on the grounds that this would be computationally inefficient. I don't think there is an elegant solution (without unnecessary allocations) at the moment, short of adding a triangular-matrix/dense-matrix rdiv! specialization that simply calls trsm! under the hood; but that would have to happen in the main Julia repo, and I'm not sure that's the best way to solve this issue.
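
Roughly the shape of the specialization being described, written as a hypothetical standalone helper rather than an actual LinearAlgebra method (CPU BLAS shown here; a CuArray method would forward to CUDA.CUBLAS.trsm! in exactly the same way, which is essentially what the workaround snippet above already does inline):

using LinearAlgebra
using LinearAlgebra: BLAS

# Hypothetical helper: in-place right division of a dense matrix by the
# adjoint of an upper triangular factor, forwarded straight to trsm!.
# Solves X * U' = B in place, overwriting B with X = B * U⁻ᵀ.
function rdiv_adjoint_ut!(B::StridedMatrix{T},
                          U::UpperTriangular{T,<:StridedMatrix{T}}) where {T<:BLAS.BlasFloat}
    trans = T <: Complex ? 'C' : 'T'
    return BLAS.trsm!('R', 'U', trans, 'N', one(T), parent(U), B)
end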

@ToucheSir @devmotion

@ToucheSir
Member

If you can come up with a solution that doesn't require any additional dependencies (excluding CUDA.jl itself), we could add it to a block like @init @require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" begin ... end. Far from ideal, but in the absence of an upstream interface it's probably the best we can do.
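
To make the shape of that concrete, here is a hypothetical standalone version of the Requires.jl pattern referred to above (the module name is made up, the UUID is CUDA.jl's; inside Zygote this would sit next to the existing @init @require blocks, with the CuArray cholesky rule in the body):

module CholeskyCUDAGlue   # hypothetical module name, for illustration only

using Requires

function __init__()
    @require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" begin
        # This block only runs once the user has loaded CUDA.jl, so the
        # CuArray-specific cholesky rule can live here without making
        # CUDA.jl a hard dependency of the package.
        @info "CUDA.jl detected; the CuArray cholesky rule would be defined here"
    end
end

end # module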

@Red-Portal
Author

I personally don't understand why we don't deserve the triangular matrix-dense matrix specialization; after all, BLAS supports it. So unless someone gives me a compelling reason why that shouldn't exist, I think upstream (and us...) would be better off with it.

@ToucheSir
Member

You'll have to bring that up on the CUDA.jl side, but my understanding is that they're not comfortable making CUDA.CUBLAS just a bunch of specializations of methods on LinearAlgebra.BLAS because of differences in behaviour and non-overlapping API surface area.

@Red-Portal
Author

Oh, I'm talking about the actual upstream LinearAlgebra.jl, not CUDA.jl.

@ToucheSir
Member

Well, you've lost me, which indicates this should probably be taken up as an upstream issue :)

@devmotion
Collaborator

devmotion commented Jun 8, 2022

I guess one possible solution, if neither the Zygote nor the CR rules work with CUDA here, could be to add a CR definition for cholesky(::CuArray) to CUDA and remove the adjoint in Zygote (so that Zygote actually uses the CR implementation)?

Of course, it would be nice if the definition in CR also worked (efficiently) on GPUs, but I guess it's unavoidable that sometimes one has to specialize on the array type.
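
For illustration, such a rule could look roughly like the sketch below, reusing the CUBLAS-based pullback from the workaround earlier in the thread (assuming ChainRulesCore is available). The names and the exact tangent handling are assumptions: the incoming cotangent is taken to expose a factors field, as in the stack trace above, and the copytri! call is swapped for a broadcast-based symmetrization so nothing relies on scalar indexing.

using CUDA, LinearAlgebra, ChainRulesCore

function ChainRulesCore.rrule(::typeof(cholesky), Σ::CuMatrix; check::Bool = true)
    C = cholesky(Σ; check = check)
    function cholesky_pullback(Δ)
        issuccess(C) || throw(PosDefException(C.info))
        U, Ū = C.U, Δ.factors

        Σ̄ = triu(Ū) * triu(U.data)'        # materialized triangles, plain GEMM
        Σ̄ = triu(Σ̄) + triu(Σ̄, 1)'          # symmetrize, keeping the upper triangle
        ldiv!(U, Σ̄)                         # Σ̄ ← U⁻¹ Σ̄
        CUDA.CUBLAS.trsm!('R', 'U', 'T', 'N', one(eltype(Σ)), U.data, Σ̄)  # Σ̄ ← Σ̄ U⁻ᵀ
        Σ̄[diagind(Σ̄)] ./= 2
        return NoTangent(), UpperTriangular(Σ̄)
    end
    return C, cholesky_pullback
end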

@devmotion
Collaborator

Can you recheck, @Red-Portal? This should be fixed on the master branch now that #1114 has been merged.

@Red-Portal
Author

Uh, I'm still experiencing some issues. I'm unavailable next week, so I'll take a deeper look after that.

@Red-Portal
Author

@devmotion I checked again and it seems to work. I don't know why I didn't get it right last time. Regardless, LGTM. Cheers to everyone who made this possible!
