Skip to content

Commit

Permalink
Copy parent to dest in general in copyto! for triangular (JuliaLang#5…
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub authored and lazarusA committed Jul 12, 2024
1 parent a3cd5b0 commit 5a6b789
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
7 changes: 2 additions & 5 deletions stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -533,10 +533,7 @@ function copyto!(dest::StridedMatrix, U::UpperOrLowerTriangular)
end
function _copyto!(dest::StridedMatrix, U::UpperOrLowerTriangular)
copytrito!(dest, parent(U), U isa UpperOrUnitUpperTriangular ? 'U' : 'L')
_triangularize!(U)(dest)
if U isa Union{UnitUpperTriangular, UnitLowerTriangular}
dest[diagind(dest)] .= @view U[diagind(U, IndexCartesian())]
end
copytrito!(dest, U, U isa UpperOrUnitUpperTriangular ? 'L' : 'U')
return dest
end
function _copyto!(dest::StridedMatrix, U::UpperOrLowerTriangular{<:Any, <:StridedMatrix})
Expand All @@ -545,7 +542,7 @@ function _copyto!(dest::StridedMatrix, U::UpperOrLowerTriangular{<:Any, <:Stride
return dest
end
# for strided matrices, we explicitly loop over the arrays to improve cache locality
# This fuses the copytrito! and triu/l operations
# This fuses the copytrito! for the two halves
function copyto_unaliased!(dest::StridedMatrix, U::UpperOrUnitUpperTriangular{<:Any, <:StridedMatrix})
isunit = U isa UnitUpperTriangular
for col in axes(dest,2)
Expand Down
9 changes: 9 additions & 0 deletions stdlib/LinearAlgebra/test/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,15 @@ Random.seed!(1)
for M in (D, Bu, Bl, Tri, Sym)
@test Matrix(M) == zeros(TypeWithZero, 3, 3)
end

mutable struct MTypeWithZero end
Base.convert(::Type{MTypeWithZero}, ::TypeWithoutZero) = MTypeWithZero()
Base.convert(::Type{MTypeWithZero}, ::TypeWithZero) = MTypeWithZero()
Base.zero(x::MTypeWithZero) = zero(typeof(x))
Base.zero(::Type{MTypeWithZero}) = MTypeWithZero()
U = UpperTriangular(Symmetric(fill(TypeWithoutZero(), 2, 2)))
M = Matrix{MTypeWithZero}(U)
@test all(x -> x isa MTypeWithZero, M)
end
end

Expand Down

0 comments on commit 5a6b789

Please sign in to comment.