Skip to content

Commit

Permalink
only define cholesky_lower and cholesky_upper rules for ReverseDiff, …
Browse files Browse the repository at this point in the history
…remove rules ChainRules defs
  • Loading branch information
torfjelde committed Aug 7, 2023
1 parent 29790dc commit 3241936
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 52 deletions.
54 changes: 43 additions & 11 deletions ext/BijectorsReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ if isdefined(Base, :get_extension)
find_alpha,
pd_from_lower,
lower_triangular,
upper_triangular
upper_triangular,
transpose_eager

using Bijectors.LinearAlgebra
using Bijectors.Compat: eachcol
Expand Down Expand Up @@ -77,7 +78,8 @@ else
find_alpha,
pd_from_lower,
lower_triangular,
upper_triangular
upper_triangular,
transpose_eager

using ..Bijectors.LinearAlgebra
using ..Bijectors.Compat: eachcol
Expand Down Expand Up @@ -262,21 +264,51 @@ end
@grad_from_chainrules _link_chol_lkj(x::TrackedMatrix)
@grad_from_chainrules _inv_link_chol_lkj(x::TrackedVector)


cholesky_lower(X::TrackedMatrix) = track(cholesky_lower, X)
@grad function cholesky_lower(X::TrackedMatrix)
X_val = value(X)
y, y_pullback = ChainRulesCore.rrule(cholesky_lower, X_val)
return y, last y_pullback
@grad function cholesky_lower(X_tracked::TrackedMatrix)
X = value(X_tracked)
H, hermitian_pullback = ChainRulesCore.rrule(Hermitian, X, :L)
C, cholesky_pullback = ChainRulesCore.rrule(cholesky, H, Val(false))
function cholesky_lower_pullback(ΔL)
ΔC = ChainRulesCore.Tangent{typeof(C)}(; factors=(C.uplo === :L ? ΔL : ΔL'))
ΔH = cholesky_pullback(ΔC)[2]
Δx = hermitian_pullback(ΔH)[2]
# No need to add pullback for `lower_triangular`, because the pullback
# for `Hermitian` already produces the correct result (i.e. the lower-triangular
# part zeroed out).
return (Δx,)
end

return lower_triangular(parent(C.L)), cholesky_lower_pullback
end

cholesky_upper(X::TrackedMatrix) = track(cholesky_upper, X)
@grad function cholesky_upper(X::TrackedMatrix)
X_val = value(X)
y, y_pullback = ChainRulesCore.rrule(cholesky_upper, X_val)
return y, last y_pullback
@grad function cholesky_upper(X_tracked::TrackedMatrix)
X = value(X_tracked)
H, hermitian_pullback = ChainRulesCore.rrule(Hermitian, X, :U)
C, cholesky_pullback = ChainRulesCore.rrule(cholesky, H, Val(false))
function cholesky_upper_pullback(ΔU)
ΔC = ChainRulesCore.Tangent{typeof(C)}(; factors=(C.uplo === :U ? ΔU : ΔU'))
ΔH = cholesky_pullback(ΔC)[2]
Δx = hermitian_pullback(ΔH)[2]
# No need to add pullback for `upper_triangular`, because the pullback
# for `Hermitian` already produces the correct result (i.e. the upper-triangular
# part zeroed out).
return (Δx,)
end

return upper_triangular(parent(C.U)), cholesky_upper_pullback
end

transpose_eager(X::TrackedMatrix) = track(transpose_eager, X)
@grad function transpose_eager(X_tracked::TrackedMatrix)
X = value(X_tracked)
y, y_pullback = ChainRulesCore.rrule(permutedims, X, (2, 1))
transpose_eager_pullback(Δ) = (y_pullback(Δ)[2],)
return y, transpose_eager_pullback
end

@grad_from_chainrules Bijectors.transpose_eager(X::TrackedMatrix)

if VERSION <= v"1.8.0-DEV.1526"
# HACK: This dispatch does not wrap X in Hermitian before calling cholesky.
Expand Down
41 changes: 0 additions & 41 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,6 @@ cholesky_factor(X::LowerTriangular) = X

# HACK: Allows us to define custom chain rules while we wait for upstream fixes.
transpose_eager(X::AbstractMatrix) = permutedims(X)
function ChainRulesCore.rrule(::typeof(transpose_eager), X::AbstractMatrix)
y, y_pullback = ChainRulesCore.rrule(permutedims, X, (2, 1))
function transpose_eager_pullback(Δ)
return (ChainRulesCore.NoTangent(), y_pullback(Δ)[2])
end
return y, transpose_eager_pullback
end

# TODO: Add `check` as an argument?
"""
Expand All @@ -41,23 +34,6 @@ rather than `LowerTriangular`.
but with a custom `ChainRulesCore.rrule` implementation.
"""
cholesky_lower(X::AbstractMatrix) = lower_triangular(parent(cholesky(Hermitian(X)).L))
function ChainRulesCore.rrule(::typeof(cholesky_lower), X::AbstractMatrix)
project_to = ChainRulesCore.ProjectTo(X)
H, hermitian_pullback = ChainRulesCore.rrule(Hermitian, X, :L)
C, cholesky_pullback = ChainRulesCore.rrule(cholesky, H, Val(false))
function cholesky_lower_pullback(_ΔL)
ΔL = ChainRulesCore.unthunk(_ΔL)
ΔC = ChainRulesCore.Tangent{typeof(C)}(; factors=(C.uplo === :L ? ΔL : ΔL'))
ΔH = cholesky_pullback(ΔC)[2]
Δx = hermitian_pullback(ΔH)[2]
# No need to add pullback for `lower_triangular`, because the pullback
# for `Hermitian` already produces the correct result (i.e. the lower-triangular
# part zeroed out).
return (ChainRulesCore.NoTangent(), project_to(Δx))
end

return lower_triangular(parent(C.L)), cholesky_lower_pullback
end

"""
cholesky_upper(X)
Expand All @@ -70,23 +46,6 @@ rather than `UpperTriangular`.
but with a custom `ChainRulesCore.rrule` implementation.
"""
cholesky_upper(X::AbstractMatrix) = upper_triangular(parent(cholesky(Hermitian(X)).U))
function ChainRulesCore.rrule(::typeof(cholesky_upper), X::AbstractMatrix)
project_to = ChainRulesCore.ProjectTo(X)
H, hermitian_pullback = ChainRulesCore.rrule(Hermitian, X, :U)
C, cholesky_pullback = ChainRulesCore.rrule(cholesky, H, Val(false))
function cholesky_upper_pullback(_ΔU)
ΔU = ChainRulesCore.unthunk(_ΔU)
ΔC = ChainRulesCore.Tangent{typeof(C)}(; factors=(C.uplo === :U ? ΔU : ΔU'))
ΔH = cholesky_pullback(ΔC)[2]
Δx = hermitian_pullback(ΔH)[2]
# No need to add pullback for `upper_triangular`, because the pullback
# for `Hermitian` already produces the correct result (i.e. the upper-triangular
# part zeroed out).
return (ChainRulesCore.NoTangent(), project_to(Δx))
end

return upper_triangular(parent(C.U)), cholesky_upper_pullback
end

"""
triu_mask(X::AbstractMatrix, k::Int)
Expand Down

0 comments on commit 3241936

Please sign in to comment.