From 3241936c17cae4426edebd3cde4a13502e861ee5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 7 Aug 2023 07:20:06 +0100 Subject: [PATCH] only define cholesky_lower and cholesky_upper rules for ReverseDiff, remove rules ChainRules defs --- ext/BijectorsReverseDiffExt.jl | 54 +++++++++++++++++++++++++++------- src/utils.jl | 41 -------------------------- 2 files changed, 43 insertions(+), 52 deletions(-) diff --git a/ext/BijectorsReverseDiffExt.jl b/ext/BijectorsReverseDiffExt.jl index 9a008674..3d3d9b3d 100644 --- a/ext/BijectorsReverseDiffExt.jl +++ b/ext/BijectorsReverseDiffExt.jl @@ -36,7 +36,8 @@ if isdefined(Base, :get_extension) find_alpha, pd_from_lower, lower_triangular, - upper_triangular + upper_triangular, + transpose_eager using Bijectors.LinearAlgebra using Bijectors.Compat: eachcol @@ -77,7 +78,8 @@ else find_alpha, pd_from_lower, lower_triangular, - upper_triangular + upper_triangular, + transpose_eager using ..Bijectors.LinearAlgebra using ..Bijectors.Compat: eachcol @@ -262,21 +264,51 @@ end @grad_from_chainrules _link_chol_lkj(x::TrackedMatrix) @grad_from_chainrules _inv_link_chol_lkj(x::TrackedVector) + cholesky_lower(X::TrackedMatrix) = track(cholesky_lower, X) -@grad function cholesky_lower(X::TrackedMatrix) - X_val = value(X) - y, y_pullback = ChainRulesCore.rrule(cholesky_lower, X_val) - return y, last ∘ y_pullback +@grad function cholesky_lower(X_tracked::TrackedMatrix) + X = value(X_tracked) + H, hermitian_pullback = ChainRulesCore.rrule(Hermitian, X, :L) + C, cholesky_pullback = ChainRulesCore.rrule(cholesky, H, Val(false)) + function cholesky_lower_pullback(ΔL) + ΔC = ChainRulesCore.Tangent{typeof(C)}(; factors=(C.uplo === :L ? ΔL : ΔL')) + ΔH = cholesky_pullback(ΔC)[2] + Δx = hermitian_pullback(ΔH)[2] + # No need to add a pullback for `lower_triangular`: the `Hermitian` pullback + # presumably already yields the right cotangent, i.e. one whose entries outside the + # `:L` triangle are zero (NOTE(review): confirm against the ChainRules rule). 
+ return (Δx,) + end + + return lower_triangular(parent(C.L)), cholesky_lower_pullback end cholesky_upper(X::TrackedMatrix) = track(cholesky_upper, X) -@grad function cholesky_upper(X::TrackedMatrix) - X_val = value(X) - y, y_pullback = ChainRulesCore.rrule(cholesky_upper, X_val) - return y, last ∘ y_pullback +@grad function cholesky_upper(X_tracked::TrackedMatrix) + X = value(X_tracked) + H, hermitian_pullback = ChainRulesCore.rrule(Hermitian, X, :U) + C, cholesky_pullback = ChainRulesCore.rrule(cholesky, H, Val(false)) + function cholesky_upper_pullback(ΔU) + ΔC = ChainRulesCore.Tangent{typeof(C)}(; factors=(C.uplo === :U ? ΔU : ΔU')) + ΔH = cholesky_pullback(ΔC)[2] + Δx = hermitian_pullback(ΔH)[2] + # No need to add a pullback for `upper_triangular`: the `Hermitian` pullback + # presumably already yields the right cotangent, i.e. one whose entries outside the + # `:U` triangle are zero (NOTE(review): confirm against the ChainRules rule). + return (Δx,) + end + + return upper_triangular(parent(C.U)), cholesky_upper_pullback +end + +transpose_eager(X::TrackedMatrix) = track(transpose_eager, X) +@grad function transpose_eager(X_tracked::TrackedMatrix) + X = value(X_tracked) + y, y_pullback = ChainRulesCore.rrule(permutedims, X, (2, 1)) + transpose_eager_pullback(Δ) = (y_pullback(Δ)[2],) + return y, transpose_eager_pullback end -@grad_from_chainrules Bijectors.transpose_eager(X::TrackedMatrix) if VERSION <= v"1.8.0-DEV.1526" # HACK: This dispatch does not wrap X in Hermitian before calling cholesky. diff --git a/src/utils.jl b/src/utils.jl index aa15ada6..5ec8b0a3 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -21,13 +21,6 @@ cholesky_factor(X::LowerTriangular) = X # HACK: Allows us to define custom chain rules while we wait for upstream fixes. 
transpose_eager(X::AbstractMatrix) = permutedims(X) -function ChainRulesCore.rrule(::typeof(transpose_eager), X::AbstractMatrix) - y, y_pullback = ChainRulesCore.rrule(permutedims, X, (2, 1)) - function transpose_eager_pullback(Δ) - return (ChainRulesCore.NoTangent(), y_pullback(Δ)[2]) - end - return y, transpose_eager_pullback -end # TODO: Add `check` as an argument? """ @@ -41,23 +34,6 @@ rather than `LowerTriangular`. but with a custom `ChainRulesCore.rrule` implementation. """ cholesky_lower(X::AbstractMatrix) = lower_triangular(parent(cholesky(Hermitian(X)).L)) -function ChainRulesCore.rrule(::typeof(cholesky_lower), X::AbstractMatrix) - project_to = ChainRulesCore.ProjectTo(X) - H, hermitian_pullback = ChainRulesCore.rrule(Hermitian, X, :L) - C, cholesky_pullback = ChainRulesCore.rrule(cholesky, H, Val(false)) - function cholesky_lower_pullback(_ΔL) - ΔL = ChainRulesCore.unthunk(_ΔL) - ΔC = ChainRulesCore.Tangent{typeof(C)}(; factors=(C.uplo === :L ? ΔL : ΔL')) - ΔH = cholesky_pullback(ΔC)[2] - Δx = hermitian_pullback(ΔH)[2] - # No need to add pullback for `lower_triangular`, because the pullback - # for `Hermitian` already produces the correct result (i.e. the lower-triangular - # part zeroed out). - return (ChainRulesCore.NoTangent(), project_to(Δx)) - end - - return lower_triangular(parent(C.L)), cholesky_lower_pullback -end """ cholesky_upper(X) @@ -70,23 +46,6 @@ rather than `UpperTriangular`. but with a custom `ChainRulesCore.rrule` implementation. """ cholesky_upper(X::AbstractMatrix) = upper_triangular(parent(cholesky(Hermitian(X)).U)) -function ChainRulesCore.rrule(::typeof(cholesky_upper), X::AbstractMatrix) - project_to = ChainRulesCore.ProjectTo(X) - H, hermitian_pullback = ChainRulesCore.rrule(Hermitian, X, :U) - C, cholesky_pullback = ChainRulesCore.rrule(cholesky, H, Val(false)) - function cholesky_upper_pullback(_ΔU) - ΔU = ChainRulesCore.unthunk(_ΔU) - ΔC = ChainRulesCore.Tangent{typeof(C)}(; factors=(C.uplo === :U ? 
ΔU : ΔU')) - ΔH = cholesky_pullback(ΔC)[2] - Δx = hermitian_pullback(ΔH)[2] - # No need to add pullback for `upper_triangular`, because the pullback - # for `Hermitian` already produces the correct result (i.e. the upper-triangular - # part zeroed out). - return (ChainRulesCore.NoTangent(), project_to(Δx)) - end - - return upper_triangular(parent(C.U)), cholesky_upper_pullback -end """ triu_mask(X::AbstractMatrix, k::Int)