Backpropagation through backsolve calls inv for non-square matrices #46

Open
baggepinnen opened this issue Sep 4, 2019 · 1 comment

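@baggepinnen (Contributor)

Differentiating through `A\y` fails when `A` is non-square. Here `A` is 5×2, so `A\y` is a least-squares solve. The forward pass works, but the backward pass for `\` calls `inv(A)` (frames [5] and [6] in the stack trace), which throws a `DimensionMismatch` for a non-square matrix.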

```julia
using Flux
using Flux: param, Params, params, gradient
A = param(randn(5,2))  # tall 5×2 matrix, so A\y is a least-squares solve
y = param(randn(5))
ls(A,y) = A\y
```

```julia
julia> ls(A,y)
Tracked 2-element Array{Float64,1}:
 -0.3218743888384424
 -3.2253094136376728

julia> gradient(()->sum(ls(A,y)), params(A,y))
ERROR: DimensionMismatch("matrix is not square: dimensions are (5, 2)")
Stacktrace:
 [1] inv(::Array{Float64,2}) at /home/fredrikb/julia-1.3.0-rc1/share/julia/stdlib/v1.3/LinearAlgebra/src/LinearAlgebra.jl:221
 [2] _forward at /home/fredrikb/.julia/packages/Tracker/SAr25/src/lib/array.jl:278 [inlined]
 [3] #track#1 at /home/fredrikb/.julia/packages/Tracker/SAr25/src/Tracker.jl:51 [inlined]
 [4] track at /home/fredrikb/.julia/packages/Tracker/SAr25/src/Tracker.jl:51 [inlined]
 [5] inv at /home/fredrikb/.julia/packages/Tracker/SAr25/src/lib/array.jl:276 [inlined]
 [6] #478 at /home/fredrikb/.julia/packages/Tracker/SAr25/src/lib/array.jl:303 [inlined]
 [7] back_(::Tracker.Call{Tracker.var"##478#479"{TrackedArray{,Array{Float64,2}},TrackedArray{,Array{Float64,1}}},Tuple{Tracker.Tracked{Array{Float64,2}},Tracker.Tracked{Array{Float64,1}}}}, ::Array{Float64,1}, ::Bool) at /home/fredrikb/.julia/packages/Tracker/SAr25/src/back.jl:35
 [8] back(::Tracker.Tracked{Array{Float64,1}}, ::Array{Float64,1}, ::Bool) at /home/fredrikb/.julia/packages/Tracker/SAr25/src/back.jl:58
 [9] foreach at /home/fredrikb/.julia/packages/Tracker/SAr25/src/back.jl:38 [inlined]
 [10] back_(::Tracker.Call{Tracker.var"##482#483"{TrackedArray{,Array{Float64,1}}},Tuple{Tracker.Tracked{Array{Float64,1}}}}, ::Float64, ::Bool) at /home/fredrikb/.julia/packages/Tracker/SAr25/src/back.jl:38
 [11] back(::Tracker.Tracked{Float64}, ::Int64, ::Bool) at /home/fredrikb/.julia/packages/Tracker/SAr25/src/back.jl:58
 [12] #back!#15 at /home/fredrikb/.julia/packages/Tracker/SAr25/src/back.jl:77 [inlined]
 [13] #back! at ./none:0 [inlined]
 [14] #back!#32 at /home/fredrikb/.julia/packages/Tracker/SAr25/src/lib/real.jl:16 [inlined]
 [15] back!(::Tracker.TrackedReal{Float64}) at /home/fredrikb/.julia/packages/Tracker/SAr25/src/lib/real.jl:14
 [16] gradient_(::var"##27#28", ::Params) at /home/fredrikb/.julia/packages/Tracker/SAr25/src/back.jl:4
 [17] #gradient#24(::Bool, ::typeof(gradient), ::Function, ::Params) at /home/fredrikb/.julia/packages/Tracker/SAr25/src/back.jl:164
 [18] gradient(::Function, ::Params) at /home/fredrikb/.julia/packages/Tracker/SAr25/src/back.jl:164
 [19] top-level scope at REPL[68]:1
```
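For reference, the backward pass of a least-squares solve does not need `inv(A)`. Below is a minimal sketch in plain Julia (no Tracker). It assumes `A` is tall (m > n) with full column rank, so that `A\y` is the least-squares solution `x = (A'A)^(-1) A'y`. `ls_pullback` is a hypothetical helper for illustration, not Tracker's API or implementation. Given the upstream gradient `xbar`, it returns `ybar = A*z` and `Abar = r*z' - ybar*x'`, where `z = (A'A)^(-1) xbar` and `r = y - A*x` is the residual. When `A` is square, `r == 0` and this reduces to the familiar `Abar = -ybar*x'`.

```julia
using LinearAlgebra

# Hypothetical pullback for x = A \ y, assuming A is m×n with m > n and full
# column rank, so x is the least-squares solution x = (A'A)^(-1) A'y.
# Illustrative sketch only; not Tracker's implementation.
function ls_pullback(A::AbstractMatrix, y::AbstractVector, xbar::AbstractVector)
    x = A \ y                      # forward pass: least-squares solution
    r = y - A * x                  # residual; nonzero in general when m > n
    R = UpperTriangular(qr(A).R)   # thin QR factor: A'A == R'R, no inv(A) needed
    z = R \ (R' \ xbar)            # z = (A'A)^(-1) * xbar
    ybar = A * z                   # gradient w.r.t. y
    Abar = r * z' - ybar * x'      # gradient w.r.t. A; equals -ybar*x' when r == 0
    return Abar, ybar
end

# Check against central finite differences of f(A, y) = sum(A \ y),
# the same loss as in the issue, so the upstream gradient is ones(n).
f(A, y) = sum(A \ y)
A = randn(5, 2)
y = randn(5)
Abar, ybar = ls_pullback(A, y, ones(2))

h = 1e-6
Abar_fd = zeros(size(A))
for i in eachindex(A)
    E = zeros(size(A)); E[i] = h   # perturb one entry of A
    Abar_fd[i] = (f(A + E, y) - f(A - E, y)) / (2h)
end
ybar_fd = [(f(A, y + h * e) - f(A, y - h * e)) / (2h) for e in eachcol(Matrix(1.0I, 5, 5))]

println(isapprox(Abar, Abar_fd; rtol = 1e-4), " ", isapprox(ybar, ybar_fd; rtol = 1e-4))
```

Solving with the thin QR factor `R` avoids forming `A'A` explicitly. The wide case (m < n, where `\` returns the minimum-norm solution) needs a different rule and is not covered by this sketch.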
@baggepinnen (Contributor, Author)

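One possible workaround for some situations is to solve the normal equations explicitly. `A'A` is square (2×2 here), so the backward pass through `\` no longer fails. This matches `A\y` only when `A` has full column rank, and forming `A'A` squares the condition number, so it is unsuitable for ill-conditioned `A`: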

```julia
ls(A,y) = (A'A)\(A'y)  # normal equations: A'A is square, so the \ backward pass works
ls(A,y)
gradient(()->sum(ls(A,y)), params(A,y))
```
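A quick way to see when this workaround is safe: a small sketch in plain Julia (no Tracker) comparing `A\y` with the normal-equations form on a random tall matrix. The two agree up to rounding when `A` has full column rank, while `cond(A'A)` equals `cond(A)^2`, so the workaround loses accuracy as `A` becomes ill-conditioned.

```julia
using LinearAlgebra

A = randn(5, 2)   # random tall matrix; full column rank with probability 1
y = randn(5)

x_qr = A \ y              # least squares via pivoted QR (what ls originally used)
x_ne = (A'A) \ (A'y)      # normal equations (the workaround)

# Same answer for a full-column-rank A, up to rounding...
println(norm(x_qr - x_ne))
# ...but the normal equations square the 2-norm condition number.
println(cond(A)^2, " ≈ ", cond(A'A))
```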
