Skip to content

Commit

Permalink
Add specialization for Array arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Sep 27, 2023
1 parent 5b376f2 commit 0aca2f3
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 30 deletions.
37 changes: 26 additions & 11 deletions src/pdiagmat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,16 +118,6 @@ function quad(a::PDiagMat, x::AbstractVecOrMat)
return vec(sum(abs2.(x) .* a.diag; dims = 1))
end
end
function invquad(a::PDiagMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
if x isa AbstractVector
return invwsumsq(a.diag, x)
else
# map(Base.Fix1(invquad, a), eachcol(x)) or similar alternatives
# do NOT return a `SVector` for inputs `x::SMatrix`.
return vec(sum(abs2.(x) ./ a.diag; dims = 1))
end
end

function quad!(r::AbstractArray, a::PDiagMat, x::AbstractMatrix)
ad = a.diag
Expand All @@ -143,8 +133,18 @@ function quad!(r::AbstractArray, a::PDiagMat, x::AbstractMatrix)
r
end

function invquad(a::PDiagMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
if x isa AbstractVector
return invwsumsq(a.diag, x)
else
# map(Base.Fix1(invquad, a), eachcol(x)) or similar alternatives
# do NOT return a `SVector` for inputs `x::SMatrix`.
return vec(sum(abs2.(x) ./ a.diag; dims = 1))
end
end

function invquad!(r::AbstractArray, a::PDiagMat, x::AbstractMatrix)
m, n = size(x)
ad = a.diag
@check_argdims eachindex(ad) == axes(x, 1)
@check_argdims eachindex(r) == axes(x, 2)
Expand Down Expand Up @@ -184,3 +184,18 @@ function Xt_invA_X(a::PDiagMat, x::AbstractMatrix)
z = x ./ sqrt.(a.diag)
transpose(z) * z
end

### Specializations for `Array` arguments with reduced allocations

function quad(a::PDiagMat{<:Real,<:Vector}, x::Matrix)
@check_argdims a.dim == size(x, 1)
T = typeof(zero(eltype(a)) * abs2(zero(eltype(x))))
return quad!(Vector{T}(undef, size(x, 2)), a, x)
end

function invquad(a::PDiagMat{<:Real,<:Vector}, x::Matrix)
@check_argdims a.dim == size(x, 1)
T = typeof(abs2(zero(eltype(x))) / zero(eltype(a)))
return invquad!(Vector{T}(undef, size(x, 2)), a, x)
end

35 changes: 26 additions & 9 deletions src/pdmat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,6 @@ function quad(a::PDMat, x::AbstractVecOrMat)
return vec(sum(abs2, aU_x; dims = 1))
end
end
function invquad(a::PDMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
inv_aL_x = chol_lower(cholesky(a)) \ x
if x isa AbstractVector
return sum(abs2, inv_aL_x)
else
return vec(sum(abs2, inv_aL_x; dims = 1))
end
end

function quad!(r::AbstractArray, a::PDMat, x::AbstractMatrix)
@check_argdims axes(r) == axes(x, 2)
Expand All @@ -135,6 +126,17 @@ function quad!(r::AbstractArray, a::PDMat, x::AbstractMatrix)
end
return r
end

function invquad(a::PDMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
inv_aL_x = chol_lower(cholesky(a)) \ x
if x isa AbstractVector
return sum(abs2, inv_aL_x)
else
return vec(sum(abs2, inv_aL_x; dims = 1))
end
end

function invquad!(r::AbstractArray, a::PDMat, x::AbstractMatrix)
@check_argdims axes(r) == axes(x, 2)
@check_argdims a.dim == size(x, 1)
Expand Down Expand Up @@ -173,3 +175,18 @@ function Xt_invA_X(a::PDMat, x::AbstractMatrix)
z = chol_lower(a.chol) \ x
return transpose(z) * z
end

### Specializations for `Array` arguments with reduced allocations

function quad(a::PDMat{<:Real,<:Vector}, x::Matrix)
@check_argdims a.dim == size(x, 1)
T = typeof(zero(eltype(a)) * abs2(zero(eltype(x))))
return quad!(Vector{T}(undef, size(x, 2)), a, x)
end

function invquad(a::PDMat{<:Real,<:Vector}, x::Matrix)
@check_argdims a.dim == size(x, 1)
T = typeof(abs2(zero(eltype(x))) / zero(eltype(a)))
return invquad!(Vector{T}(undef, size(x, 2)), a, x)
end

11 changes: 6 additions & 5 deletions src/pdsparsemat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,6 @@ function quad(a::PDSparseMat, x::AbstractVecOrMat)
z = sparse(chol_lower(cholesky(a)))' * x
return x isa AbstractVector ? sum(abs2, z) : vec(sum(abs2, z; dims = 1))
end
function invquad(a::PDSparseMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
z = sparse(chol_lower(cholesky(a))) \ x
return x isa AbstractVector ? sum(abs2, z) : vec(sum(abs2, z; dims = 1))
end

function quad!(r::AbstractArray, a::PDSparseMat, x::AbstractMatrix)
@check_argdims eachindex(r) == axes(x, 2)
Expand All @@ -114,6 +109,12 @@ function quad!(r::AbstractArray, a::PDSparseMat, x::AbstractMatrix)
return r
end

function invquad(a::PDSparseMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
z = sparse(chol_lower(cholesky(a))) \ x
return x isa AbstractVector ? sum(abs2, z) : vec(sum(abs2, z; dims = 1))
end

function invquad!(r::AbstractArray, a::PDSparseMat, x::AbstractMatrix)
@check_argdims eachindex(r) == axes(x, 2)
for i in axes(x, 2)
Expand Down
12 changes: 7 additions & 5 deletions src/scalmat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,13 @@ function quad(a::ScalMat, x::AbstractVecOrMat)
return vec(sum(wsq, x; dims=1))
end
end

function quad!(r::AbstractArray, a::ScalMat, x::AbstractMatrix)
@check_argdims eachindex(r) == axes(x, 2)
@check_argdims a.dim == size(x, 1)
return map!(Base.Fix1(quad, a), r, eachcol(x))
end

function invquad(a::ScalMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
if x isa AbstractVector
Expand All @@ -125,11 +132,6 @@ function invquad(a::ScalMat, x::AbstractVecOrMat)
end
end

function quad!(r::AbstractArray, a::ScalMat, x::AbstractMatrix)
@check_argdims eachindex(r) == axes(x, 2)
@check_argdims a.dim == size(x, 1)
return map!(Base.Fix1(quad, a), r, eachcol(x))
end
function invquad!(r::AbstractArray, a::ScalMat, x::AbstractMatrix)
@check_argdims eachindex(r) == axes(x, 2)
@check_argdims a.dim == size(x, 1)
Expand Down

0 comments on commit 0aca2f3

Please sign in to comment.