
Add array level overloads for * and dot #133

Open
adrhill opened this issue Jun 21, 2024 · 8 comments
Labels
  • array: Features regarding array overloads
  • new overloads: A new method on tracers is required by a user.

Comments

@adrhill
Owner

adrhill commented Jun 21, 2024

These overloads were not added in #131.

Currently, we make use of the generic matmul fallback from LinearAlgebra.
Some performance can be gained by adding dedicated methods for multiplying, e.g.,

  • a matrix of tracers by a vector without tracers
  • a matrix without tracers by a vector of tracers

since these can be reduced to simple sums of tracers.
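As a minimal sketch of why these two cases reduce to sums, assume a toy first-order tracer, `ToyTracer` (hypothetical, not the package's tracer type), that only records the set of input indices a value depends on. Scaling by a nonzero constant keeps that set and adding tracers unions the sets, so each output entry is the union over one row. `tracer_sum` and `toy_matvec` are hypothetical names as well.

```julia
# Hypothetical toy tracer (not the package's tracer type): the set of input
# indices a value depends on.
struct ToyTracer
    inds::Set{Int}
end

# Pattern of a sum of tracers: the union of their index sets.
tracer_sum(ts) = ToyTracer(reduce(union, (t.inds for t in ts); init=Set{Int}()))

# Matrix of reals * vector of tracers: y[i] = sum_j A[i, j] * x[j].
# A nonzero real factor keeps a tracer's pattern and an exact zero drops the term,
# so y[i] is just the sum of the x[j] with A[i, j] != 0.
function toy_matvec(A::AbstractMatrix{<:Real}, x::AbstractVector{ToyTracer})
    size(A, 2) == length(x) || throw(DimensionMismatch("size(A, 2) must equal length(x)"))
    return [tracer_sum(xj for (aij, xj) in zip(row, x) if !iszero(aij)) for row in eachrow(A)]
end

# Matrix of tracers * vector of reals: the same sum, filtered on x instead.
function toy_matvec(A::AbstractMatrix{ToyTracer}, x::AbstractVector{<:Real})
    size(A, 2) == length(x) || throw(DimensionMismatch("size(A, 2) must equal length(x)"))
    return [tracer_sum(aij for (aij, xj) in zip(row, x) if !iszero(xj)) for row in eachrow(A)]
end

x = [ToyTracer(Set([1])), ToyTracer(Set([2])), ToyTracer(Set([3]))]
A = [1.0 0.0 2.0;
     0.0 3.0 0.0]
y = toy_matvec(A, x)   # y[1] depends on inputs 1 and 3, y[2] on input 2
```

Dropping terms whose constant factor is exactly zero is a choice made for this sketch; it is sound because those entries are known constants, but the thread does not say whether the package's overloads would do the same.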

adrhill added the array label (Features regarding array overloads) on Jun 26, 2024
gdalle added the bug label (Something isn't working) on Aug 6, 2024
@gdalle
Collaborator

gdalle commented Aug 6, 2024

We may also need some more subtle overloads involving Symmetric{SparseMatrixCSC}:

julia> using SparseConnectivityTracer, SparseArrays, LinearAlgebra

julia> A = Symmetric(sparse(rand(2, 2)));

julia> f(x) = A * x
f (generic function with 1 method)

julia> jacobian_sparsity(f, rand(2), TracerSparsityDetector())
ERROR: TypeError: non-boolean (SparseConnectivityTracer.GradientTracer{SparseConnectivityTracer.IndexSetGradientPattern{Int64, BitSet}}) used in boolean context
Stacktrace:
  [1] _mul!(nzrang::typeof(SparseArrays.nzrangeup), diagop::typeof(identity), odiagop::typeof(transpose), C::Vector{…}, A::SparseMatrixCSC{…}, B::Vector{…}, α::SparseConnectivityTracer.GradientTracer{…}, β::SparseConnectivityTracer.GradientTracer{…})
    @ SparseArrays ~/.julia/juliaup/julia-1.10.4+0.x64.linux.gnu/share/julia/stdlib/v1.10/SparseArrays/src/linalg.jl:881
  [2] spdensemul!(C::Vector{…}, tA::Char, tB::Char, A::SparseMatrixCSC{…}, B::Vector{…}, _add::LinearAlgebra.MulAddMul{…})
    @ SparseArrays ~/.julia/juliaup/julia-1.10.4+0.x64.linux.gnu/share/julia/stdlib/v1.10/SparseArrays/src/linalg.jl:50
  [3] generic_matvecmul!
    @ ~/.julia/juliaup/julia-1.10.4+0.x64.linux.gnu/share/julia/stdlib/v1.10/SparseArrays/src/linalg.jl:35 [inlined]
  [4] mul!
    @ ~/.julia/juliaup/julia-1.10.4+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:66 [inlined]
  [5] mul!
    @ ~/.julia/juliaup/julia-1.10.4+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:237 [inlined]
  [6] *(A::Symmetric{Float64, SparseMatrixCSC{…}}, x::Vector{SparseConnectivityTracer.GradientTracer{…}})
    @ LinearAlgebra ~/.julia/juliaup/julia-1.10.4+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:57
  [7] f(x::Vector{SparseConnectivityTracer.GradientTracer{SparseConnectivityTracer.IndexSetGradientPattern{Int64, BitSet}}})
    @ Main ./REPL[25]:1
  [8] trace_function(::Type{SparseConnectivityTracer.GradientTracer{…}}, f::typeof(f), x::Vector{Float64})
    @ SparseConnectivityTracer ~/Work/GitHub/Julia/SparseConnectivityTracer.jl/src/interface.jl:33
  [9] _jacobian_sparsity(f::Function, x::Vector{Float64}, ::Type{SparseConnectivityTracer.GradientTracer{…}})
    @ SparseConnectivityTracer ~/Work/GitHub/Julia/SparseConnectivityTracer.jl/src/interface.jl:59
 [10] jacobian_sparsity(f::Function, x::Vector{…}, ::TracerSparsityDetector{…})
    @ SparseConnectivityTracer ~/Work/GitHub/Julia/SparseConnectivityTracer.jl/src/adtypes.jl:49
 [11] top-level scope
    @ REPL[26]:1
Some type information was truncated. Use `show(err)` to see complete types.

@adrhill
Owner Author

adrhill commented Aug 6, 2024

> We may also need some more subtle overloads involving Symmetric{SparseMatrixCSC}

This issue is already complex enough for simple AbstractMatrix and AbstractVector.

Multiplication requires methods for:

  1. Matrix of tracers * Matrix of tracers
  2. Matrix of tracers * Vector of tracers
  3. Transposed vector of tracers * Matrix of tracers
  4. Transposed vector of tracers * Vector of tracers
  5. Matrix of reals * Matrix of tracers
  6. Matrix of reals * Vector of tracers
  7. Transposed vector of reals * Matrix of tracers
  8. Transposed vector of reals * Vector of tracers
  9. Matrix of tracers * Matrix of reals
  10. Matrix of tracers * Vector of reals
  11. Transposed vector of tracers * Matrix of reals
  12. Transposed vector of tracers * Vector of reals
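The twelve cases are the product of three eltype combinations (tracer * tracer, real * tracer, tracer * real), two left shapes and two right shapes. The sketch below makes that count explicit by generating the signatures with `@eval` and forwarding them to one shared kernel. `ToyTracer`, `toy_mul` and `toy_mul_kernel` are hypothetical stand-ins, and the kernel only reports operand sizes; this is not how the package structures its overloads.

```julia
using LinearAlgebra

# Hypothetical toy tracer standing in for the package's tracer type.
struct ToyTracer
    inds::Set{Int}
end

# Hypothetical shared kernel; here it only reports the operand sizes.
toy_mul_kernel(A, B) = (size(A), size(B))

# Left operands: matrix or transposed vector. Right operands: matrix or vector.
const TRACER_LEFT  = (AbstractMatrix{ToyTracer}, Transpose{ToyTracer,<:AbstractVector})
const REAL_LEFT    = (AbstractMatrix{<:Real}, Transpose{<:Real,<:AbstractVector})
const TRACER_RIGHT = (AbstractMatrix{ToyTracer}, AbstractVector{ToyTracer})
const REAL_RIGHT   = (AbstractMatrix{<:Real}, AbstractVector{<:Real})

# 3 eltype combinations x 2 left shapes x 2 right shapes = the 12 listed cases.
for (lefts, rights) in ((TRACER_LEFT, TRACER_RIGHT), (REAL_LEFT, TRACER_RIGHT), (TRACER_LEFT, REAL_RIGHT))
    for L in lefts, R in rights
        @eval toy_mul(A::$L, B::$R) = toy_mul_kernel(A, B)
    end
end
```

Real * real is left out on purpose, so `toy_mul(ones(2, 2), ones(2))` has no matching method and stays with the existing code paths.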

We want to support SparseArrays, possibly StaticArrays (#144) and ComponentArrays (#145).

And now all of these have to work inside of Diagonal, Symmetric and friends? Where do we draw the line?

@gdalle
Collaborator

gdalle commented Aug 6, 2024

I don't know. I feel like we might want to start advertising local sparsity tracing a bit more for this reason.

@adrhill
Owner Author

adrhill commented Aug 6, 2024

I don't want to export / document any internals yet to avoid breaking changes. But once the paper is released and our internals are settled, we should provide some utilities and documentation for user-defined overloads (which users will ideally upstream here). Similar to ChainRules' frules, this has to turn into a community effort at some point.

@tmigot

tmigot commented Aug 6, 2024

Encouraging user-defined overloads is definitely an idea. I remember that ReverseDiff.jl had some issues caused by the package defining too many functions, and there was no easy answer...

@adrhill
Owner Author

adrhill commented Aug 6, 2024

I'm confident we can find a generic implementation for multiplication of arrays of tracers using eachrow and eachcol, which all common array types should support.
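A minimal sketch of such a generic implementation, under the same toy assumptions (hypothetical `ToyTracer`, `pattern`, `term_pattern` and `toy_matmul`, not the package's API): each output entry depends on whatever the nonzero terms of one row of the left operand times one column of the right operand depend on. Since the kernel only uses `eachrow`, `eachcol` and iteration, it also accepts `Diagonal`, `Symmetric` and sparse matrices, at the cost of ignoring their structure for performance.

```julia
using LinearAlgebra, SparseArrays

# Hypothetical toy tracer: the set of input indices a value depends on.
struct ToyTracer
    inds::Set{Int}
end

# Dependencies contributed by one entry: constants contribute none.
pattern(::Real) = Set{Int}()
pattern(t::ToyTracer) = t.inds

# A product a * b with an exact constant zero factor vanishes; otherwise it
# depends on whatever either factor depends on.
is_const_zero(a::Real) = iszero(a)
is_const_zero(::ToyTracer) = false
term_pattern(a, b) = (is_const_zero(a) || is_const_zero(b)) ? Set{Int}() : union(pattern(a), pattern(b))

# Generic kernel for any mix of tracer and real eltypes:
# C[i, j] = sum_k A[i, k] * B[k, j] depends on the union of its term patterns.
function toy_matmul(A::AbstractMatrix, B::AbstractMatrix)
    size(A, 2) == size(B, 1) || throw(DimensionMismatch("inner dimensions must match"))
    return [ToyTracer(reduce(union, (term_pattern(a, b) for (a, b) in zip(row, col)); init=Set{Int}()))
            for row in eachrow(A), col in eachcol(B)]
end

# A vector right operand is treated as a one-column matrix.
toy_matmul(A::AbstractMatrix, x::AbstractVector) = vec(toy_matmul(A, reshape(x, :, 1)))

x = [ToyTracer(Set([1])), ToyTracer(Set([2]))]
y_diag = toy_matmul(Diagonal([1.0, 2.0]), x)
y_sym = toy_matmul(Symmetric([1.0 2.0; 0.0 3.0]), x)
y_sparse = toy_matmul(sparse([1.0 0.0; 0.0 0.0]), x)
```

`Symmetric([1.0 2.0; 0.0 3.0])` only reads its upper triangle, so both entries of `y_sym` depend on both inputs: the wrapper's own `getindex` is respected without a dedicated method.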

The only issue is that Julia dispatches on the most specific "outer type", not on our tracer eltype, as discussed here: #144 (comment)

As mentioned in the linked comment, this could be solved by sticking our array overloads in a big loop over different array types. Or by calling something like overload_matmul(Symmetric{SparseMatrixCSC}). That would work better with package extensions (needed for #144).
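The dispatch problem can be reproduced without the package. In the toy below (all names hypothetical, with `toy_mul` standing in for `*` and the `Symmetric` method standing in for a specialization a library already defines), the tracer method only constrains the eltype of `x`, so for a `Symmetric` matrix the call is ambiguous. A helper in the spirit of `overload_matmul` then adds one method per outer type, which is more specific than both candidates.

```julia
using LinearAlgebra

# Hypothetical toy tracer standing in for the package's tracer type.
struct ToyTracer
    inds::Set{Int}
end

# Stand-ins for methods a library already has: a generic fallback and a
# method specialized on the outer array type.
toy_mul(A::AbstractMatrix, x::AbstractVector) = :generic_fallback
toy_mul(A::Symmetric, x::AbstractVector) = :symmetric_method

# Our method dispatches on the eltype of x, not on the outer type of A.
toy_mul(A::AbstractMatrix, x::AbstractVector{ToyTracer}) = :tracer_method

A = Symmetric([1.0 2.0; 2.0 3.0])
x = [ToyTracer(Set([1])), ToyTracer(Set([2]))]

# For (Symmetric, Vector{ToyTracer}) neither of the last two methods is more
# specific than the other, so Julia throws a MethodError (ambiguity).
was_ambiguous = try
    toy_mul(A, x)
    false
catch err
    err isa MethodError
end

# Hypothetical helper in the spirit of `overload_matmul(...)`: add the tracer
# method for one outer array type M, resolving the ambiguity for M.
function overload_toy_mul(M::Type)
    @eval toy_mul(A::$M, x::AbstractVector{ToyTracer}) = :tracer_method
    return nothing
end

overload_toy_mul(Symmetric)
```

Every outer type with its own specialization needs such a call, which is where the list of array types starts to grow.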

@adrhill
Owner Author

adrhill commented Aug 6, 2024

But before all of that can be tackled, we first need the methods for * mentioned in the comment above.

@gdalle
Collaborator

gdalle commented Aug 6, 2024

This big loop over array types is what ReverseDiff does, and it is a real can of worms, so if we can avoid it, I think we should.

adrhill added the new overloads label (A new method on tracers is required by a user.) and removed the bug label (Something isn't working) on Aug 7, 2024
Development

No branches or pull requests

3 participants