Skip to content

Commit

Permalink
Merge pull request #718 from JuliaDiff/ox/1.9fixes
Browse files Browse the repository at this point in the history
Fix for julia 1.9
  • Loading branch information
oxinabox authored Jun 2, 2023
2 parents 11c230c + fb8cacd commit 50d9d03
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 7 deletions.
58 changes: 54 additions & 4 deletions src/rulesets/Base/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -342,20 +342,70 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R
project_B = ProjectTo(B)

Y = A \ B

function backslash_pullback(ȳ)
= unthunk(ȳ)

Ȳf =
@static if VERSION >= v"1.9"
# Need to ensure Ȳ is an array since since https://github.com/JuliaLang/julia/pull/44358
if !isa(Ȳ, AbstractArray)
Ȳf = [Ȳ]
end
end
Yf = Y
@static if VERSION >= v"1.9"
# Need to ensure Yf is an array since since https://github.com/JuliaLang/julia/pull/44358
if !isa(Y, AbstractArray)
Yf = [Y]
end
end
#@info "vars" typeof(Ȳ) typeof(Y) typeof(Yf) typeof(A) typeof(B)
∂A = @thunk begin
= A' \
= A' \ Ȳf
= -* Y'
= add!!(Ā, (B - A * Y) *' / A')
= add!!(Ā, A' \ Y * (Ȳ' -'A))
t = (B - A * Y) *'
@static if VERSION >= v"1.9"
# Need to ensure t is an array since since https://github.com/JuliaLang/julia/pull/44358
if !isa(t, AbstractArray)
t = [t]
end
end
= add!!(Ā, t / A')
= add!!(Ā, A' \ Yf * (Ȳ' -'A))
project_A(Ā)
end
∂B = @thunk project_B(A' \ )
∂B = @thunk project_B(A' \ Ȳf)
return NoTangent(), ∂A, ∂B
end
return Y, backslash_pullback
end

@static if VERSION >= v"1.9"
# Need to ensure things are not scalar since since https://github.com/JuliaLang/julia/pull/44358
_maybe_descalar(x) = x isa AbstractArray ? x : [x]
else
_maybe_descalar(x) = x
end

function rrule(A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real})
Y = A \ B


function backslash_pullback(ȳ)
= unthunk(ȳ)

∂A = @thunk begin
= A' \ _maybe_descalar(Ȳ)
= -* Y'
+= _maybe_descalar((B - A * Y) *') / A'
+= (A' \ _maybe_descalar(Y)) * (Ȳ' -'A)
(Ā)
end
∂B = @thunk (A' \ _maybe_descalar(Ȳ))
return ∂A, ∂B
end
return Y, backslash_pullback
end

#####
Expand Down
6 changes: 3 additions & 3 deletions test/rulesets/Base/arraymath.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
@testset "arraymath.jl" begin
@testset "inv(::Matrix{$T})" for T in (Float64, ComplexF64)
B = generate_well_conditioned_matrix(T, 3)
if VERSION >= v"1.7"
if v"1.7" <= VERSION < v"1.9"
@gpu test_frule(inv, B)
@gpu test_rrule(inv, B)
else
Expand Down Expand Up @@ -167,12 +167,12 @@
@testset "Matrix $f Vector" begin
X = randn(10, 4)
y = randn(10)
test_rrule(f, X, y)
test_rrule(f, X, y; check_inferred=false)
end
@testset "Vector $f Matrix" begin
x = randn(10)
Y = randn(10, 4)
test_rrule(f, x, Y; output_tangent=Transpose(rand(4)))
test_rrule(f, x, Y; output_tangent=Transpose(rand(4)), check_inferred=false)
end
else
A = rand(2, 4)
Expand Down

0 comments on commit 50d9d03

Please sign in to comment.