Skip to content

Commit

Permalink
only do minimal change to rule for \ to convert to array
Browse files Browse the repository at this point in the history
Also make second Y not scalar

more coercing some things into arrays some of the time

cleaner def with a helper function
  • Loading branch information
oxinabox committed Jun 2, 2023
1 parent ccd4196 commit fb8cacd
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 10 deletions.
59 changes: 51 additions & 8 deletions src/rulesets/Base/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -343,28 +343,71 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R

Y = A \ B

Atf = factorize(A')

function backslash_pullback(ȳ)
= unthunk(ȳ)

Ȳf =
@static if VERSION >= v"1.9"
# Need to ensure Ȳ is an array since https://github.com/JuliaLang/julia/pull/44358
isa AbstractArray ||= [Ȳ]
if !isa(Ȳ, AbstractArray)
Ȳf = [Ȳ]
end
end
Yf = Y
@static if VERSION >= v"1.9"
# Need to ensure Yf is an array since https://github.com/JuliaLang/julia/pull/44358
if !isa(Y, AbstractArray)
Yf = [Y]
end
end
Atf = factorize(A')
#@info "vars" typeof(Ȳ) typeof(Y) typeof(Yf) typeof(A) typeof(B)
∂A = @thunk begin
= Atf \
= A' \ Ȳf
= -* Y'
= add!!(Ā, ((B - A * Y) *') / Atf)
= add!!(Ā, Atf \ Y * (Ȳ' -'A))
t = (B - A * Y) *'
@static if VERSION >= v"1.9"
# Need to ensure t is an array since https://github.com/JuliaLang/julia/pull/44358
if !isa(t, AbstractArray)
t = [t]
end
end
= add!!(Ā, t / A')
= add!!(Ā, A' \ Yf * (Ȳ' -'A))
project_A(Ā)
end
∂B = @thunk project_B(Atf \ )
∂B = @thunk project_B(A' \ Ȳf)
return NoTangent(), ∂A, ∂B
end
return Y, backslash_pullback
end

# `_maybe_descalar(x)`: on Julia 1.9 and later, return `x` unchanged when it is already an
# `AbstractArray` and otherwise wrap it in a one-element vector `[x]`; on earlier versions
# it is the identity. The wrapping is needed because results must not be scalar since
# https://github.com/JuliaLang/julia/pull/44358
@static if VERSION < v"1.9"
    _maybe_descalar(x) = x
else
    _maybe_descalar(x::AbstractArray) = x
    _maybe_descalar(x) = [x]
end

function rrule(A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real})

This comment has been minimized.

Copy link
@marius311

marius311 Jun 29, 2023

I'm definitely missing the context of this PR, sorry, but I'm just wondering what the intention is with this definition and if it's not possibly accidentally missing something? Afaik `(A::AbstractVecOrMat{<:Real})(B::AbstractVecOrMat{<:Real})` isn't defined in Julia, and it certainly doesn't correspond to `A \ B` as line 392 suggests, no? I ran into this because I do have some `(A::CustomArray)(B::CustomArray)` defined and this broke that (which I could fix, but when I tracked it down I saw this and I can't quite make sense of it). @oxinabox

This comment has been minimized.

Copy link
@oxinabox

oxinabox Jun 29, 2023

Author Member

Oh no, looks like this got messed up and should never have been defined;
it looks like a bad thing that got pasted in by mistake.

Y = A \ B


function backslash_pullback(ȳ)
= unthunk(ȳ)

∂A = @thunk begin
= A' \ _maybe_descalar(Ȳ)
= -* Y'
+= _maybe_descalar((B - A * Y) *') / A'
+= (A' \ _maybe_descalar(Y)) * (Ȳ' -'A)
(Ā)
end
∂B = @thunk (A' \ _maybe_descalar(Ȳ))
return ∂A, ∂B
end
return Y, backslash_pullback
end

#####
##### `\`, `/` matrix-scalar_rule
#####
Expand Down
4 changes: 2 additions & 2 deletions test/rulesets/Base/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,12 @@
@testset "Matrix $f Vector" begin
X = randn(10, 4)
y = randn(10)
test_rrule(f, X, y)
test_rrule(f, X, y; check_inferred=false)
end
@testset "Vector $f Matrix" begin
x = randn(10)
Y = randn(10, 4)
test_rrule(f, x, Y; output_tangent=Transpose(rand(4)))
test_rrule(f, x, Y; output_tangent=Transpose(rand(4)), check_inferred=false)
end
else
A = rand(2, 4)
Expand Down

0 comments on commit fb8cacd

Please sign in to comment.