Add array level overloads for `*` and `dot` #133
We may also need some more subtle overloads, as this example with a `Symmetric` sparse matrix shows:

```julia
julia> using SparseConnectivityTracer, SparseArrays, LinearAlgebra

julia> A = Symmetric(sparse(rand(2, 2)));

julia> f(x) = A * x
f (generic function with 1 method)

julia> jacobian_sparsity(f, rand(2), TracerSparsityDetector())
ERROR: TypeError: non-boolean (SparseConnectivityTracer.GradientTracer{SparseConnectivityTracer.IndexSetGradientPattern{Int64, BitSet}}) used in boolean context
Stacktrace:
  [1] _mul!(nzrang::typeof(SparseArrays.nzrangeup), diagop::typeof(identity), odiagop::typeof(transpose), C::Vector{…}, A::SparseMatrixCSC{…}, B::Vector{…}, α::SparseConnectivityTracer.GradientTracer{…}, β::SparseConnectivityTracer.GradientTracer{…})
    @ SparseArrays ~/.julia/juliaup/julia-1.10.4+0.x64.linux.gnu/share/julia/stdlib/v1.10/SparseArrays/src/linalg.jl:881
  [2] spdensemul!(C::Vector{…}, tA::Char, tB::Char, A::SparseMatrixCSC{…}, B::Vector{…}, _add::LinearAlgebra.MulAddMul{…})
    @ SparseArrays ~/.julia/juliaup/julia-1.10.4+0.x64.linux.gnu/share/julia/stdlib/v1.10/SparseArrays/src/linalg.jl:50
  [3] generic_matvecmul!
    @ ~/.julia/juliaup/julia-1.10.4+0.x64.linux.gnu/share/julia/stdlib/v1.10/SparseArrays/src/linalg.jl:35 [inlined]
  [4] mul!
    @ ~/.julia/juliaup/julia-1.10.4+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:66 [inlined]
  [5] mul!
    @ ~/.julia/juliaup/julia-1.10.4+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:237 [inlined]
  [6] *(A::Symmetric{Float64, SparseMatrixCSC{…}}, x::Vector{SparseConnectivityTracer.GradientTracer{…}})
    @ LinearAlgebra ~/.julia/juliaup/julia-1.10.4+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:57
  [7] f(x::Vector{SparseConnectivityTracer.GradientTracer{SparseConnectivityTracer.IndexSetGradientPattern{Int64, BitSet}}})
    @ Main ./REPL[25]:1
  [8] trace_function(::Type{SparseConnectivityTracer.GradientTracer{…}}, f::typeof(f), x::Vector{Float64})
    @ SparseConnectivityTracer ~/Work/GitHub/Julia/SparseConnectivityTracer.jl/src/interface.jl:33
  [9] _jacobian_sparsity(f::Function, x::Vector{Float64}, ::Type{SparseConnectivityTracer.GradientTracer{…}})
    @ SparseConnectivityTracer ~/Work/GitHub/Julia/SparseConnectivityTracer.jl/src/interface.jl:59
 [10] jacobian_sparsity(f::Function, x::Vector{…}, ::TracerSparsityDetector{…})
    @ SparseConnectivityTracer ~/Work/GitHub/Julia/SparseConnectivityTracer.jl/src/adtypes.jl:49
 [11] top-level scope
    @ REPL[26]:1
Some type information was truncated. Use `show(err)` to see complete types.
```
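The `TypeError` in this trace is raised by Julia itself whenever `if`, `&&`, `||` or `?:` receive something other than a `Bool`. The failing frame is SparseArrays' `_mul!`, which receives `α` and `β` as tracers, so presumably it branches on one of them. The minimal sketch below reproduces that mechanism under stated assumptions: `ToyTracer` and `maybe_scale` are hypothetical names, and the rule that comparing a tracer returns another tracer is an assumption for illustration, not a description of how SparseConnectivityTracer is implemented.

```julia
# Hypothetical toy tracer for illustration only; it is not one of
# SparseConnectivityTracer's types.
struct ToyTracer
    deps::BitSet  # input indices this value depends on
end

# Assumption for illustration: comparing a tracer returns another tracer
# instead of a Bool, standing in for a tracer that carries no primal value.
Base.:(!=)(a::ToyTracer, ::Number) = a

# Hypothetical generic kernel that special-cases a scalar coefficient β,
# loosely mimicking the kind of check a `mul!` implementation might perform.
function maybe_scale(β)
    if β != 1  # `if` requires a Bool condition
        return :scale
    else
        return :skip
    end
end

maybe_scale(2.0)  # :scale, since `2.0 != 1` is a Bool
maybe_scale(1)    # :skip

# With a tracer, `β != 1` yields a ToyTracer, so `if` throws
# "TypeError: non-boolean (ToyTracer) used in boolean context".
threw_typeerror = try
    maybe_scale(ToyTracer(BitSet([1])))
    false
catch err
    err isa TypeError
end
```

Any value-dependent branch in a generic array method fails this way for tracers without a primal value, which is why such methods need tracer-specific overloads rather than the generic fallback.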
This issue is already complex enough for simple multiplication, which requires methods for: …
We want to support SparseArrays, and possibly StaticArrays (#144) and ComponentArrays (#145). And now all of these also have to work inside of …
I don't know. I feel like we might want to start advertising local sparsity tracing a bit more for this reason.
I don't want to export or document any internals yet, to avoid breaking changes. But once the paper is released and our internals are settled, we should provide some utilities and documentation for user-defined overloads (which users will ideally upstream here), similar to ChainRules' …
Encouraging user-defined overloads is definitely an option. I remember that ReverseDiff.jl had some issues because the package defined too many functions, and there was no easy answer...
I'm confident we can find a generic implementation for multiplication of arrays of tracers using … The only issue is that Julia methods dispatch on the most specific "outer type" (the array type), not on our tracer eltype, as discussed in #144 (comment). As mentioned in the linked comment, this could be solved by putting our array overloads in a big loop over different array types, or by calling something like …
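To illustrate the dispatch problem described in the comment above, here is a self-contained sketch using a hypothetical function `toymul` and a hypothetical `ToyTracer` type instead of the real `*` and SparseConnectivityTracer's types. A library method specialised on the container (`Diagonal`) and a tracer method specialised on the eltype are each more specific in one argument only, so Julia reports an ambiguity. Generating one method per wrapper type in a loop, as suggested, resolves it, but only for the types listed.

```julia
using LinearAlgebra

# Hypothetical toy tracer type (illustration only)
struct ToyTracer
    deps::BitSet
end

# Hypothetical stand-in for a library function whose methods are
# specialised on the *outer* container type, as is common for `*`.
toymul(A::AbstractMatrix, x::AbstractVector) = :generic
toymul(A::Diagonal, x::AbstractVector) = :diagonal

# A tracer package would rather specialise on the *element* type.
toymul(A::AbstractMatrix, x::AbstractVector{<:ToyTracer}) = :tracer

xs = [ToyTracer(BitSet([i])) for i in 1:2]
D = Diagonal([1.0, 2.0])

# Plain Matrix: the eltype-based method is strictly more specific.
dense_result = toymul(ones(2, 2), xs)  # :tracer

# Diagonal: each method is more specific in one argument only,
# so Julia cannot pick one and throws an ambiguity MethodError.
was_ambiguous = try
    toymul(D, xs)
    false
catch err
    err isa MethodError
end

# Workaround from the thread: generate one tracer method per wrapper
# type; each is strictly more specific than both methods above.
for M in (:Diagonal, :UpperTriangular, :LowerTriangular)
    @eval toymul(A::$M, x::AbstractVector{<:ToyTracer}) = :tracer
end

resolved_result = toymul(D, xs)  # :tracer
```

The loop only covers the wrapper types it enumerates; every additional array type (for instance from StaticArrays or ComponentArrays) would have to be added by hand, which is the maintenance burden discussed in this thread.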
But before all of that can be tackled, we first need the methods for …
This big loop over array types is what ReverseDiff does, and it is a real can of worms, so if we can avoid it, I think we should.
These overloads were not added in #131.
Currently, we make use of the generic `matmul` fallback from LinearAlgebra. Some performance can be gained by adding methods for multiplication of e.g. …, since these can be reduced to simple sums of tracers.
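As a rough illustration of why such products reduce to sums of tracers, here is a sketch with a hypothetical `ToyTracer` that only records the set of input indices it depends on (similar in spirit to the `IndexSetGradientPattern{Int64, BitSet}` in the stack trace, but not SparseConnectivityTracer's actual implementation). The hypothetical `toy_matvec` walks the stored entries of a `SparseMatrixCSC` using the SparseArrays accessors `rowvals`, `nonzeros` and `nzrange`, so each output is the union of the dependency sets of the inputs in its row; the hypothetical `toy_dot` treats `dot` the same way. It assumes a global pattern, where scaling by a constant never changes dependencies.

```julia
using SparseArrays

# Hypothetical toy tracer: only records which inputs a value depends on.
struct ToyTracer
    deps::BitSet
end

empty_tracer() = ToyTracer(BitSet())

# A sum of tracers depends on the union of their inputs.
Base.:+(a::ToyTracer, b::ToyTracer) = ToyTracer(union(a.deps, b.deps))
# Global-pattern assumption: scaling by a constant keeps the dependencies.
Base.:*(::Real, t::ToyTracer) = t

# Hypothetical specialised sparse mat-vec: y[i] accumulates x[j]
# for every stored entry A[i, j], i.e. a plain sum of tracers per row.
function toy_matvec(A::SparseMatrixCSC{<:Real}, x::AbstractVector{ToyTracer})
    m, n = size(A)
    length(x) == n || throw(DimensionMismatch("A has $n columns but x has length $(length(x))"))
    y = [empty_tracer() for _ in 1:m]
    rows = rowvals(A)
    vals = nonzeros(A)
    for j in 1:n, k in nzrange(A, j)
        i = rows[k]
        y[i] = y[i] + vals[k] * x[j]
    end
    return y
end

# Hypothetical specialised dot product: the result is the sum of all
# (scaled) tracers, so it depends on every input in x.
function toy_dot(a::AbstractVector{<:Real}, x::AbstractVector{ToyTracer})
    length(a) == length(x) || throw(DimensionMismatch("vectors have different lengths"))
    return reduce(+, (ai * xi for (ai, xi) in zip(a, x)); init = empty_tracer())
end

A = sparse([1, 2, 2], [1, 1, 3], [1.0, 2.0, 3.0], 2, 3)
x = [ToyTracer(BitSet([j])) for j in 1:3]
y = toy_matvec(A, x)
collect(y[1].deps)  # [1]
collect(y[2].deps)  # [1, 3]
collect(toy_dot([1.0, 0.0, 2.0], x).deps)  # [1, 2, 3]
```

Because `toy_matvec` only accepts a `SparseMatrixCSC`, a `Symmetric` wrapper like the one in the failing example would not reach it; that case would need its own method.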