Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add array overloads #131

Merged
merged 34 commits into from
Jun 26, 2024
Merged

Add array overloads #131

merged 34 commits into from
Jun 26, 2024

Conversation

adrhill
Copy link
Owner

@adrhill adrhill commented Jun 19, 2024

Closes #115.
The scope of this PR is somewhat open-ended. What are some must haves for a first array overload PR @gdalle?

@codecov-commenter
Copy link

codecov-commenter commented Jun 19, 2024

Codecov Report

Attention: Patch coverage is 89.93711% with 16 lines in your changes missing coverage. Please review.

Project coverage is 87.45%. Comparing base (34e6bec) to head (981907c).

Files Patch % Lines
src/overloads/arrays.jl 88.07% 13 Missing ⚠️
test/test_arrays.jl 94.00% 3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #131      +/-   ##
==========================================
+ Coverage   87.23%   87.45%   +0.21%     
==========================================
  Files          31       33       +2     
  Lines        1363     1522     +159     
==========================================
+ Hits         1189     1331     +142     
- Misses        174      191      +17     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@gdalle
Copy link
Collaborator

gdalle commented Jun 20, 2024

Here are the main things I would wanna see:

  • *, mul! (unlike + and -, it can be optimized cause each left-side line and each right-side column is reused many times)
  • dot
  • norm, opnorm
  • \ (I think your idea of using pinv and * is a good first step)
  • inv, pinv
  • det, logdet, logabsdet
  • eigmax, eigmin
  • exp, ^

I think we can disregard special matrix types and factorizations for now.

@adrhill
Copy link
Owner Author

adrhill commented Jun 20, 2024

I think we can disregard special matrix types

Agreed, but there is one exception: lu on SparseMatrixCSC currently causes a stack overflow by calling itself (#108):

lu(A::AbstractSparseMatrixCSC; check::Bool = true) = lu(float(A); check = check)

This is due to us abusing float to return a tracer.

We might just as well fix it in here.

@gdalle
Copy link
Collaborator

gdalle commented Jun 20, 2024

This is due to us abusing float to return a tracer.

I've been wondering about that. What happens if we remove it?

@adrhill
Copy link
Owner Author

adrhill commented Jun 20, 2024

It does indeed look like we might not need float on our own test suite anymore.

However, external functions could call float on a tracer, as seen in the lu call above.

@gdalle
Copy link
Collaborator

gdalle commented Jun 20, 2024

However, external functions could call float on a tracer, as seen in the lu call above.

Indeed let's keep it. My thought process was that it's ethically dubious not to return an AbstractFloat when someone calls float. But then again we don't return Bools either when someone calls < so 🤷

@adrhill
Copy link
Owner Author

adrhill commented Jun 20, 2024

Somewhat interestingly, #108 is now fixed on global tracers, but not on local tracers:

julia> hessian_pattern(x -> logdet(spdiagm(x)), randn(3))
3×3 SparseMatrixCSC{Bool, Int64} with 9 stored entries:
 1  1  1
 1  1  1
 1  1  1

julia> local_hessian_pattern(x -> logdet(spdiagm(x)), randn(3))
ERROR: StackOverflowError:
Stacktrace:
 ...
 [11] float(S::SparseMatrixCSC{SparseConnectivityTracer.Dual{Float64, SparseConnectivityTracer.HessianTracer{…}}, Int64})
    @ SparseArrays ~/.julia/juliaup/julia-1.10.4+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/SparseArrays/src/sparsematrix.jl:983
 [12] lu(A::SparseMatrixCSC{SparseConnectivityTracer.Dual{…}, Int64}; check::Bool) (repeats 28852 times)
    @ SparseArrays.UMFPACK ~/.julia/juliaup/julia-1.10.4+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/SparseArrays/src/solvers/umfpack.jl:395
 [13] logabsdet(A::SparseMatrixCSC{SparseConnectivityTracer.Dual{Float64, SparseConnectivityTracer.HessianTracer{…}}, Int64})
    @ LinearAlgebra ~/.julia/juliaup/julia-1.10.4+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/LinearAlgebra/src/generic.jl:1676
 [14] logdet(A::SparseMatrixCSC{SparseConnectivityTracer.Dual{Float64, SparseConnectivityTracer.HessianTracer{…}}, Int64})
    @ LinearAlgebra ~/.julia/juliaup/julia-1.10.4+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/LinearAlgebra/src/generic.jl:1701

@adrhill
Copy link
Owner Author

adrhill commented Jun 20, 2024

\ (I think your idea of using pinv and * is a good first step)

Any concrete feedback on improving this @gdalle?

I'm planning to make heavy use of FillArrays.jl, e.g.
https://github.com/JuliaArrays/FillArrays.jl/blob/4f8a966e931208a8a5aa56c909cc5e579044e421/src/fillalgebra.jl#L82

@gdalle
Copy link
Collaborator

gdalle commented Jun 20, 2024

Somewhat interestingly, #108 is now fixed on global tracers, but not on local tracers:

That's probably because float(A) expects to create an AbstractMatrix{<:AbstractFloat}. But we don't have Dual{<:AbstractFloat} <: AbstractFloat so it loops endlessly.
I don't have a good solution for this at the moment. Maybe we need some dual array overloads too?

Any concrete feedback on improving this @gdalle?

No, actually since A \ b = pinv(A) * b in most cases (is that true?), I'd say this is close to optimal.

I'm planning to make heavy use of FillArrays.jl

This is a good idea.

@adrhill
Copy link
Owner Author

adrhill commented Jun 21, 2024

Skipping * and dot in this PR, opened #133 to track.

Project.toml Show resolved Hide resolved
test/runtests.jl Outdated Show resolved Hide resolved
src/overloads/arrays.jl Outdated Show resolved Hide resolved
src/overloads/arrays.jl Outdated Show resolved Hide resolved
src/overloads/arrays.jl Show resolved Hide resolved
src/overloads/arrays.jl Show resolved Hide resolved
test/test_arrays.jl Show resolved Hide resolved
test/test_arrays.jl Outdated Show resolved Hide resolved
test/test_arrays.jl Outdated Show resolved Hide resolved
test/test_arrays.jl Show resolved Hide resolved
@gdalle
Copy link
Collaborator

gdalle commented Jun 24, 2024

I'd wait until we settle #135 before merging this one, especially due to the use of myempty

@adrhill adrhill mentioned this pull request Jun 24, 2024
@adrhill adrhill requested a review from gdalle June 24, 2024 15:39
src/overloads/arrays.jl Outdated Show resolved Hide resolved
src/overloads/arrays.jl Outdated Show resolved Hide resolved
src/overloads/arrays.jl Outdated Show resolved Hide resolved
src/overloads/arrays.jl Outdated Show resolved Hide resolved
src/overloads/arrays.jl Show resolved Hide resolved
src/overloads/arrays.jl Outdated Show resolved Hide resolved
src/overloads/arrays.jl Show resolved Hide resolved
@adrhill adrhill merged commit 9e1c55f into main Jun 26, 2024
5 checks passed
@adrhill adrhill deleted the ah/array-overloads-2 branch June 26, 2024 09:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add AbstractArray{<:AbstractTracer} methods for common LinearAlgebra functions
3 participants