From 5a6b789a1745f7227dadea45a6eec0c0bf8f0d10 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Wed, 22 May 2024 13:14:26 +0530 Subject: [PATCH] Copy parent to dest in general in copyto! for triangular (#54529) --- stdlib/LinearAlgebra/src/triangular.jl | 7 ++----- stdlib/LinearAlgebra/test/special.jl | 9 +++++++++ 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/stdlib/LinearAlgebra/src/triangular.jl b/stdlib/LinearAlgebra/src/triangular.jl index ce2fb53498300..47595e06fd47d 100644 --- a/stdlib/LinearAlgebra/src/triangular.jl +++ b/stdlib/LinearAlgebra/src/triangular.jl @@ -533,10 +533,7 @@ function copyto!(dest::StridedMatrix, U::UpperOrLowerTriangular) end function _copyto!(dest::StridedMatrix, U::UpperOrLowerTriangular) copytrito!(dest, parent(U), U isa UpperOrUnitUpperTriangular ? 'U' : 'L') - _triangularize!(U)(dest) - if U isa Union{UnitUpperTriangular, UnitLowerTriangular} - dest[diagind(dest)] .= @view U[diagind(U, IndexCartesian())] - end + copytrito!(dest, U, U isa UpperOrUnitUpperTriangular ? 'L' : 'U') return dest end function _copyto!(dest::StridedMatrix, U::UpperOrLowerTriangular{<:Any, <:StridedMatrix}) @@ -545,7 +542,7 @@ function _copyto!(dest::StridedMatrix, U::UpperOrLowerTriangular{<:Any, <:Stride return dest end # for strided matrices, we explicitly loop over the arrays to improve cache locality -# This fuses the copytrito! and triu/l operations +# This fuses the copytrito! for the two halves function copyto_unaliased!(dest::StridedMatrix, U::UpperOrUnitUpperTriangular{<:Any, <:StridedMatrix}) isunit = U isa UnitUpperTriangular for col in axes(dest,2) diff --git a/stdlib/LinearAlgebra/test/special.jl b/stdlib/LinearAlgebra/test/special.jl index be04fb564a6e8..a78767c68627e 100644 --- a/stdlib/LinearAlgebra/test/special.jl +++ b/stdlib/LinearAlgebra/test/special.jl @@ -128,6 +128,15 @@ Random.seed!(1) for M in (D, Bu, Bl, Tri, Sym) @test Matrix(M) == zeros(TypeWithZero, 3, 3) end + + mutable struct MTypeWithZero end + Base.convert(::Type{MTypeWithZero}, ::TypeWithoutZero) = MTypeWithZero() + Base.convert(::Type{MTypeWithZero}, ::TypeWithZero) = MTypeWithZero() + Base.zero(x::MTypeWithZero) = zero(typeof(x)) + Base.zero(::Type{MTypeWithZero}) = MTypeWithZero() + U = UpperTriangular(Symmetric(fill(TypeWithoutZero(), 2, 2))) + M = Matrix{MTypeWithZero}(U) + @test all(x -> x isa MTypeWithZero, M) end end