From 77b5d595eeee92499ce04366907aad63cdcabc80 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 2 Oct 2023 22:36:48 +0200 Subject: [PATCH] Support StaticArrays in X(t)_(inv)A_X(t) with ScalMat (#181) * Support StaticArrays in X(t)_(inv)A_X(t) with ScalMat * Add specializations for `Matrix` with reduced allocations --- src/scalmat.jl | 27 ++++++++++++++++++++++++--- test/specialarrays.jl | 5 ++++- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/src/scalmat.jl b/src/scalmat.jl index 6e8092e..ea672ec 100644 --- a/src/scalmat.jl +++ b/src/scalmat.jl @@ -99,20 +99,41 @@ invquad!(r::AbstractArray, a::ScalMat, x::AbstractMatrix) = colwise_sumsqinv!(r, function X_A_Xt(a::ScalMat, x::AbstractMatrix) @check_argdims LinearAlgebra.checksquare(a) == size(x, 2) - lmul!(a.value, x * transpose(x)) + a.value * (x * transpose(x)) end function Xt_A_X(a::ScalMat, x::AbstractMatrix) @check_argdims LinearAlgebra.checksquare(a) == size(x, 1) - lmul!(a.value, transpose(x) * x) + a.value * (transpose(x) * x) end function X_invA_Xt(a::ScalMat, x::AbstractMatrix) @check_argdims LinearAlgebra.checksquare(a) == size(x, 2) - _rdiv!(x * transpose(x), a.value) + (x * transpose(x)) / a.value end function Xt_invA_X(a::ScalMat, x::AbstractMatrix) + @check_argdims LinearAlgebra.checksquare(a) == size(x, 1) + (transpose(x) * x) / a.value +end + +# Specializations for `x::Matrix` with reduced allocations +function X_A_Xt(a::ScalMat, x::Matrix) + @check_argdims LinearAlgebra.checksquare(a) == size(x, 2) + lmul!(a.value, x * transpose(x)) +end + +function Xt_A_X(a::ScalMat, x::Matrix) + @check_argdims LinearAlgebra.checksquare(a) == size(x, 1) + lmul!(a.value, transpose(x) * x) +end + +function X_invA_Xt(a::ScalMat, x::Matrix) + @check_argdims LinearAlgebra.checksquare(a) == size(x, 2) + _rdiv!(x * transpose(x), a.value) +end + +function Xt_invA_X(a::ScalMat, x::Matrix) @check_argdims LinearAlgebra.checksquare(a) == size(x, 1) _rdiv!(transpose(x) * x, a.value) end diff --git a/test/specialarrays.jl b/test/specialarrays.jl index 06621e8..b812a80 100644 --- a/test/specialarrays.jl +++ b/test/specialarrays.jl @@ -20,11 +20,14 @@ using StaticArrays @test D isa PDiagMat{Float64, <:SVector{4, Float64}} @test @inferred(kron(D, D)) isa PDiagMat{Float64, <:SVector{16, Float64}} + # Scaled identity matrix + E = ScalMat(4, 1.2) + x = @SVector rand(4) X = @SMatrix rand(10, 4) Y = @SMatrix rand(4, 10) - for A in (PDS, D) + for A in (PDS, D, E) @test A * x isa SVector{4, Float64} @test A * x ≈ Matrix(A) * Vector(x)