Skip to content

Commit

Permalink
Merge branch 'master' into index-traits
Browse files Browse the repository at this point in the history
  • Loading branch information
Tokazama authored Jun 13, 2023
2 parents 9ac437d + 796eb67 commit 116da7d
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ArrayInterface"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "7.5.0"
version = "7.4.8"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
24 changes: 16 additions & 8 deletions src/ArrayInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -468,45 +468,53 @@ Returns the number.
"""
bunchkaufman_instance(a::Any) = bunchkaufman(a, check = false)

@static if VERSION < v"1.7beta"
const DEFAULT_CHOLESKY_PIVOT = Val(false)
else
const DEFAULT_CHOLESKY_PIVOT = LinearAlgebra.NoPivot()
end

"""
cholesky_instance(A, pivot = LinearAlgebra.RowMaximum()) -> cholesky_factorization_instance
Returns an instance of the Cholesky factorization object with the correct type
cheaply.
"""
function cholesky_instance(A::Matrix{T}, pivot = LinearAlgebra.RowMaximum()) where {T}
function cholesky_instance(A::Matrix{T}, pivot = DEFAULT_CHOLESKY_PIVOT) where {T}
return cholesky(similar(A, 0, 0), pivot, check = false)
end
function cholesky_instance(A::SparseMatrixCSC, pivot = LinearAlgebra.RowMaximum())

function cholesky_instance(A::Union{SparseMatrixCSC,Symmetric{<:Number,<:SparseMatrixCSC}}, pivot = DEFAULT_CHOLESKY_PIVOT)
cholesky(sparse(similar(A, 1, 1)), check = false)
end


"""
cholesky_instance(a::Number, pivot = LinearAlgebra.RowMaximum()) -> a
Returns the number.
"""
cholesky_instance(a::Number, pivot = LinearAlgebra.RowMaximum()) = a
cholesky_instance(a::Number, pivot = DEFAULT_CHOLESKY_PIVOT) = a

"""
cholesky_instance(a::Any, pivot = LinearAlgebra.RowMaximum()) -> cholesky(a, check=false)
Slow fallback which gets the instance via factorization. Should get
specialized for new matrix types.
"""
cholesky_instance(a::Any, pivot = LinearAlgebra.RowMaximum()) = cholesky(a, pivot, check = false)
cholesky_instance(a::Any, pivot = DEFAULT_CHOLESKY_PIVOT) = cholesky(a, pivot, check = false)

"""
ldlt_instance(A) -> ldlt_factorization_instance
Returns an instance of the LDLT factorization object with the correct type
cheaply.
"""
function ldlt_instance(A::Matrix{T}) where {T}
return ldlt(SymTridiagonal(similar(A, 0, 0)), check = false)
function ldlt_instance(A::Matrix{T}) where {T}
return ldlt(SymTridiagonal(similar(A, 0, 0)))
end
function ldlt_instance(A::SparseMatrixCSC)
ldlt(sparse(similar(A, 1, 1)), check = false)
ldlt(sparse(similar(A, 1, 1)), check=false)
end

"""
Expand Down Expand Up @@ -574,7 +582,7 @@ Returns an instance of the QR factorization object with the correct type
cheaply.
"""
function qr_instance(A::Matrix{T}) where {T}
LinearAlgebra.QRCompactWYQ(zeros(T,0,0),zeros(T,0,0))
LinearAlgebra.QRCompactWY(zeros(T,0,0),zeros(T,0,0))
end

function qr_instance(A::Matrix{BigFloat})
Expand Down
16 changes: 12 additions & 4 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -266,12 +266,21 @@ end
@test ArrayInterface.qr_instance(A) isa typeof(qr(A))

if !(eltype(A) <: BigFloat)
@test ArrayInterface.bunchkaufman_instance(A) isa typeof(bunchkaufman(A' * A))
@test ArrayInterface.cholesky_instance(A) isa typeof(cholesky(A' * A))
@test ArrayInterface.ldlt_instance(A) isa typeof(ldlt(SymTridiagonal(A' * A)))
@test ArrayInterface.bunchkaufman_instance(A' * A) isa typeof(bunchkaufman(A' * A))
@test ArrayInterface.cholesky_instance(A' * A) isa typeof(cholesky(A' * A))
@test ArrayInterface.ldlt_instance(SymTridiagonal(A' * A)) isa typeof(ldlt(SymTridiagonal(A' * A)))
@test ArrayInterface.svd_instance(A) isa typeof(svd(A))
end
end

for A in [sparse([1.0 2.0; 3.0 4.0])]
@test ArrayInterface.lu_instance(A) isa typeof(lu(A))
@test ArrayInterface.qr_instance(A) isa typeof(qr(A))
if VERSION >= v"1.9-"
@test ArrayInterface.cholesky_instance(A' * A) isa typeof(cholesky(A' * A))
end
@test ArrayInterface.ldlt_instance(SymTridiagonal(A' * A)) isa typeof(ldlt(SymTridiagonal(A' * A)))
end
end

@testset "known values" begin
Expand Down Expand Up @@ -322,4 +331,3 @@ end
@test @inferred(ArrayInterface.known_length(CartesianIndex(1, 2, 3))) === 3
@test @inferred(ArrayInterface.known_length((x = 1, y = 2))) === 2
end

0 comments on commit 116da7d

Please sign in to comment.