From f9c10d2f25213d1376678a11230ef71274cd0281 Mon Sep 17 00:00:00 2001 From: Yuto Horikawa Date: Sun, 14 Aug 2022 19:26:01 +0900 Subject: [PATCH] Fix `_solve` with nonisbits type (#1071) * fix indents * add BigFloat in solve test * add support for non-isbits type in _solve * bump version to v1.5.3 --- Project.toml | 2 +- src/solve.jl | 31 +++++++++++++++++++------------ test/solve.jl | 4 ++-- 3 files changed, 22 insertions(+), 15 deletions(-) diff --git a/Project.toml b/Project.toml index 673df031..f0bd6864 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "StaticArrays" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.5.2" +version = "1.5.3" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/solve.jl b/src/solve.jl index 718da9d2..fd9b351b 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -16,14 +16,14 @@ end T = typeof((one(Ta)*zero(Tb) + one(Ta)*zero(Tb))/d) @inbounds return similar_type(b, T)( ((a[2,2]*a[3,3] - a[2,3]*a[3,2])*b[1] + - (a[1,3]*a[3,2] - a[1,2]*a[3,3])*b[2] + - (a[1,2]*a[2,3] - a[1,3]*a[2,2])*b[3]) / d, + (a[1,3]*a[3,2] - a[1,2]*a[3,3])*b[2] + + (a[1,2]*a[2,3] - a[1,3]*a[2,2])*b[3]) / d, ((a[2,3]*a[3,1] - a[2,1]*a[3,3])*b[1] + - (a[1,1]*a[3,3] - a[1,3]*a[3,1])*b[2] + - (a[1,3]*a[2,1] - a[1,1]*a[2,3])*b[3]) / d, + (a[1,1]*a[3,3] - a[1,3]*a[3,1])*b[2] + + (a[1,3]*a[2,1] - a[1,1]*a[2,3])*b[3]) / d, ((a[2,1]*a[3,2] - a[2,2]*a[3,1])*b[1] + - (a[1,2]*a[3,1] - a[1,1]*a[3,2])*b[2] + - (a[1,1]*a[2,2] - a[1,2]*a[2,1])*b[3]) / d ) + (a[1,2]*a[3,1] - a[1,1]*a[3,2])*b[2] + + (a[1,1]*a[2,2] - a[1,2]*a[2,1])*b[3]) / d ) end for Sa in [(2,2), (3,3)] # not needed for Sa = (1, 1); @@ -31,18 +31,25 @@ for Sa in [(2,2), (3,3)] # not needed for Sa = (1, 1); @inline function _solve(::Size{$Sa}, ::Size{Sb}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticMatrix{<:Any, <:Any, Tb}) where {Sb, Ta, Tb} d = det(a) T = typeof((one(Ta)*zero(Tb) + one(Ta)*zero(Tb))/d) - c = similar(b, T) - for col = 1:Sb[2] - @inbounds c[:, col] = _solve(Size($Sa), Size($Sa[1],), a, b[:, col]) + if isbitstype(T) + # This if block can be removed when https://github.com/JuliaArrays/StaticArrays.jl/pull/749 is merged. + c = similar(b, T) + for col in 1:Sb[2] + @inbounds c[:, col] = _solve(Size($Sa), Size($Sa[1],), a, b[:, col]) + end + return similar_type(b, T)(c) + else + return _solve_general($(Size(Sa)), Size(Sb), a, b) end - return similar_type(b, T)(c) end end # @eval end +@inline function _solve(sa::Size, sb::Size, a::StaticMatrix, b::StaticVecOrMat) + _solve_general(sa, sb, a, b) +end - -@generated function _solve(::Size{Sa}, ::Size{Sb}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticVecOrMat{Tb}) where {Sa, Sb, Ta, Tb} +@generated function _solve_general(::Size{Sa}, ::Size{Sb}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticVecOrMat{Tb}) where {Sa, Sb, Ta, Tb} if Sa[1] != Sb[1] return quote throw(DimensionMismatch("Left and right hand side first dimensions do not match in backdivide (got sizes $Sa and $Sb)")) diff --git a/test/solve.jl b/test/solve.jl index 92d5d900..71e10225 100644 --- a/test/solve.jl +++ b/test/solve.jl @@ -3,7 +3,7 @@ using StaticArrays, Test, LinearAlgebra @testset "Solving linear system" begin @testset "Problem size: $n x $n. Matrix type: $m. Element type: $elty, Wrapper: $wrapper" for n in (1,2,3,4,5,8,15), (m, v) in ((SMatrix{n,n}, SVector{n}), (MMatrix{n,n}, MVector{n})), - elty in (Float64, Int), wrapper in (identity, Symmetric, Hermitian) + elty in (Float64, Int, BigFloat), wrapper in (identity, Symmetric, Hermitian) A = wrapper(elty.(rand(-99:2:99, n, n))) b = A * elty.(rand(2:5, n)) @@ -33,7 +33,7 @@ end @testset "Solving linear system (multiple RHS)" begin @testset "Problem size: $n x $n. Matrix type: $m1. Element type: $elty" for n in (1,2,3,4,5,8,15), (m1, m2) in ((SMatrix{n,n}, SMatrix{n,2}), (MMatrix{n,n}, MMatrix{n,2})), - elty in (Float64, Int) + elty in (Float64, Int, BigFloat) A = elty.(rand(-99:2:99, n, n)) b = A * elty.(rand(2:5, n, 2))