Skip to content

Commit

Permalink
Support StaticArrays in X(t)_(inv)A_X(t) with ScalMat (#181)
Browse files Browse the repository at this point in the history
* Support StaticArrays in X(t)_(inv)A_X(t) with ScalMat

* Add specializations for `Matrix` with reduced allocations
  • Loading branch information
devmotion authored Oct 2, 2023
1 parent 5ca3316 commit 77b5d59
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 4 deletions.
27 changes: 24 additions & 3 deletions src/scalmat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,20 +99,41 @@ invquad!(r::AbstractArray, a::ScalMat, x::AbstractMatrix) = colwise_sumsqinv!(r,

function X_A_Xt(a::ScalMat, x::AbstractMatrix)
@check_argdims LinearAlgebra.checksquare(a) == size(x, 2)
lmul!(a.value, x * transpose(x))
a.value * (x * transpose(x))
end

function Xt_A_X(a::ScalMat, x::AbstractMatrix)
@check_argdims LinearAlgebra.checksquare(a) == size(x, 1)
lmul!(a.value, transpose(x) * x)
a.value * (transpose(x) * x)
end

function X_invA_Xt(a::ScalMat, x::AbstractMatrix)
@check_argdims LinearAlgebra.checksquare(a) == size(x, 2)
_rdiv!(x * transpose(x), a.value)
(x * transpose(x)) / a.value
end

function Xt_invA_X(a::ScalMat, x::AbstractMatrix)
@check_argdims LinearAlgebra.checksquare(a) == size(x, 1)
(transpose(x) * x) / a.value
end

# Specializations for `x::Matrix` with reduced allocations
function X_A_Xt(a::ScalMat, x::Matrix)
@check_argdims LinearAlgebra.checksquare(a) == size(x, 2)
lmul!(a.value, x * transpose(x))
end

function Xt_A_X(a::ScalMat, x::Matrix)
@check_argdims LinearAlgebra.checksquare(a) == size(x, 1)
lmul!(a.value, transpose(x) * x)
end

function X_invA_Xt(a::ScalMat, x::Matrix)
@check_argdims LinearAlgebra.checksquare(a) == size(x, 2)
_rdiv!(x * transpose(x), a.value)
end

function Xt_invA_X(a::ScalMat, x::Matrix)
@check_argdims LinearAlgebra.checksquare(a) == size(x, 1)
_rdiv!(transpose(x) * x, a.value)
end
5 changes: 4 additions & 1 deletion test/specialarrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@ using StaticArrays
@test D isa PDiagMat{Float64, <:SVector{4, Float64}}
@test @inferred(kron(D, D)) isa PDiagMat{Float64, <:SVector{16, Float64}}

# Scaled identity matrix
E = ScalMat(4, 1.2)

x = @SVector rand(4)
X = @SMatrix rand(10, 4)
Y = @SMatrix rand(4, 10)

for A in (PDS, D)
for A in (PDS, D, E)
@test A * x isa SVector{4, Float64}
@test A * x Matrix(A) * Vector(x)

Expand Down

0 comments on commit 77b5d59

Please sign in to comment.