Skip to content

Commit

Permalink
Support (un)whiten and (inv)quad with static arrays (#183)
Browse files Browse the repository at this point in the history
* Support StaticArrays in X(t)_(inv)A_X(t) with ScalMat

* Add specializations for `Matrix` with reduced allocations

* Support `(un)whiten` and `(inv)quad` with static arrays

* Add specialization for `Array` arguments

* Fixes for PDSparseMat

* More fixes for PDSparseMat

* Fix tests
  • Loading branch information
devmotion authored Oct 13, 2023
1 parent 9572e79 commit 0c0f04c
Show file tree
Hide file tree
Showing 7 changed files with 296 additions and 77 deletions.
54 changes: 24 additions & 30 deletions src/generics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,6 @@ LinearAlgebra.checksquare(a::AbstractPDMat) = size(a, 1)

## whiten and unwhiten

whiten!(a::AbstractMatrix, x::AbstractVecOrMat) = whiten!(x, a, x)
unwhiten!(a::AbstractMatrix, x::AbstractVecOrMat) = unwhiten!(x, a, x)

function whiten!(r::AbstractVecOrMat, a::AbstractMatrix, x::AbstractVecOrMat)
v = _rcopy!(r, x)
ldiv!(chol_lower(cholesky(a)), v)
end

function unwhiten!(r::AbstractVecOrMat, a::AbstractMatrix, x::AbstractVecOrMat)
v = _rcopy!(r, x)
lmul!(chol_lower(cholesky(a)), v)
end

"""
whiten(a::AbstractMatrix, x::AbstractVecOrMat)
unwhiten(a::AbstractMatrix, x::AbstractVecOrMat)
Expand Down Expand Up @@ -80,35 +67,41 @@ julia> W * W'
0.0 1.0
```
"""
whiten(a::AbstractMatrix, x::AbstractVecOrMat) = whiten!(similar(x), a, x)
unwhiten(a::AbstractMatrix, x::AbstractVecOrMat) = unwhiten!(similar(x), a, x)
whiten(a::AbstractMatrix{<:Real}, x::AbstractVecOrMat) = whiten(AbstractPDMat(a), x)
unwhiten(a::AbstractMatrix{<:Real}, x::AbstractVecOrMat) = unwhiten(AbstractPDMat(a), x)

whiten!(a::AbstractMatrix{<:Real}, x::AbstractVecOrMat) = whiten!(x, a, x)
unwhiten!(a::AbstractMatrix{<:Real}, x::AbstractVecOrMat) = unwhiten!(x, a, x)

function whiten!(r::AbstractVecOrMat, a::AbstractMatrix{<:Real}, x::AbstractVecOrMat)
return whiten!(r, AbstractPDMat(a), x)
end
function unwhiten!(r::AbstractVecOrMat, a::AbstractMatrix{<:Real}, x::AbstractVecOrMat)
return unwhiten!(r, AbstractPDMat(a), x)
end

## quad

"""
quad(a::AbstractMatrix, x::AbstractVecOrMat)
Return the value of the quadratic form defined by `a` applied to `x`
Return the value of the quadratic form defined by `a` applied to `x`.
If `x` is a vector the quadratic form is `x' * a * x`. If `x` is a matrix
the quadratic form is applied column-wise.
"""
function quad(a::AbstractMatrix{T}, x::AbstractMatrix{S}) where {T<:Real, S<:Real}
@check_argdims LinearAlgebra.checksquare(a) == size(x, 1)
quad!(Array{promote_type(T, S)}(undef, size(x,2)), a, x)
function quad(a::AbstractMatrix{<:Real}, x::AbstractVecOrMat)
return quad(AbstractPDMat(a), x)
end

quad(a::AbstractMatrix, x::AbstractVector) = sum(abs2, chol_upper(cholesky(a)) * x)
invquad(a::AbstractMatrix, x::AbstractVector) = sum(abs2, chol_lower(cholesky(a)) \ x)

"""
quad!(r::AbstractArray, a::AbstractMatrix, x::AbstractMatrix)
Overwrite `r` with the value of the quadratic form defined by `a` applied columnwise to `x`
Overwrite `r` with the value of the quadratic form defined by `a` applied columnwise to `x`.
"""
quad!(r::AbstractArray, a::AbstractMatrix, x::AbstractMatrix) = colwise_dot!(r, x, a * x)

function quad!(r::AbstractArray, a::AbstractMatrix{<:Real}, x::AbstractMatrix)
return quad!(r, AbstractPDMat(a), x)
end

"""
invquad(a::AbstractMatrix, x::AbstractVecOrMat)
Expand All @@ -120,15 +113,16 @@ For most `PDMat` types this is done in a way that does not require evaluation of
If `x` is a vector the quadratic form is `x' * a * x`. If `x` is a matrix
the quadratic form is applied column-wise.
"""
invquad(a::AbstractMatrix, x::AbstractVecOrMat) = x' / a * x
function invquad(a::AbstractMatrix{T}, x::AbstractMatrix{S}) where {T<:Real, S<:Real}
@check_argdims LinearAlgebra.checksquare(a) == size(x, 1)
invquad!(Array{promote_type(T, S)}(undef, size(x,2)), a, x)
function invquad(a::AbstractMatrix{<:Real}, x::AbstractVecOrMat)
return invquad(AbstractPDMat(a), x)
end

"""
invquad!(r::AbstractArray, a::AbstractMatrix, x::AbstractMatrix)
Overwrite `r` with the value of the quadratic form defined by `inv(a)` applied columnwise to `x`
"""
invquad!(r::AbstractArray, a::AbstractMatrix, x::AbstractMatrix) = colwise_dot!(r, x, a \ x)
function invquad!(r::AbstractArray, a::AbstractMatrix{<:Real}, x::AbstractMatrix)
return invquad!(r, AbstractPDMat(a), x)
end

82 changes: 50 additions & 32 deletions src/pdiagmat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,45 +91,38 @@ LinearAlgebra.sqrt(a::PDiagMat) = PDiagMat(map(sqrt, a.diag))

### whiten and unwhiten

function whiten!(r::StridedVector, a::PDiagMat, x::StridedVector)
n = a.dim
@check_argdims length(r) == length(x) == n
v = a.diag
for i = 1:n
r[i] = x[i] / sqrt(v[i])
end
return r
function whiten!(r::AbstractVecOrMat, a::PDiagMat, x::AbstractVecOrMat)
@check_argdims axes(r) == axes(x)
@check_argdims a.dim == size(x, 1)
return r .= x ./ sqrt.(a.diag)
end

function unwhiten!(r::StridedVector, a::PDiagMat, x::StridedVector)
n = a.dim
@check_argdims length(r) == length(x) == n
v = a.diag
for i = 1:n
r[i] = x[i] * sqrt(v[i])
end
return r
function unwhiten!(r::AbstractVecOrMat, a::PDiagMat, x::AbstractVecOrMat)
@check_argdims axes(r) == axes(x)
@check_argdims a.dim == size(x, 1)
return r .= x .* sqrt.(a.diag)
end

function whiten!(r::StridedMatrix, a::PDiagMat, x::StridedMatrix)
r .= x ./ sqrt.(a.diag)
return r
function whiten(a::PDiagMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
return x ./ sqrt.(a.diag)
end

function unwhiten!(r::StridedMatrix, a::PDiagMat, x::StridedMatrix)
r .= x .* sqrt.(a.diag)
return r
function unwhiten(a::PDiagMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
return x .* sqrt.(a.diag)
end


whiten!(r::AbstractVecOrMat, a::PDiagMat, x::AbstractVecOrMat) = r .= x ./ sqrt.(a.diag)
unwhiten!(r::AbstractVecOrMat, a::PDiagMat, x::AbstractVecOrMat) = r .= x .* sqrt.(a.diag)


### quadratic forms

quad(a::PDiagMat, x::AbstractVector) = wsumsq(a.diag, x)
invquad(a::PDiagMat, x::AbstractVector) = invwsumsq(a.diag, x)
function quad(a::PDiagMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
if x isa AbstractVector
return wsumsq(a.diag, x)
else
# map(Base.Fix1(invquad, a), eachcol(x)) or similar alternatives
# do NOT return a `SVector` for inputs `x::SMatrix`.
return vec(sum(abs2.(x) .* a.diag; dims = 1))
end
end

function quad!(r::AbstractArray, a::PDiagMat, x::AbstractMatrix)
ad = a.diag
Expand All @@ -145,8 +138,18 @@ function quad!(r::AbstractArray, a::PDiagMat, x::AbstractMatrix)
r
end

function invquad(a::PDiagMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
if x isa AbstractVector
return invwsumsq(a.diag, x)
else
# map(Base.Fix1(invquad, a), eachcol(x)) or similar alternatives
# do NOT return a `SVector` for inputs `x::SMatrix`.
return vec(sum(abs2.(x) ./ a.diag; dims = 1))
end
end

function invquad!(r::AbstractArray, a::PDiagMat, x::AbstractMatrix)
m, n = size(x)
ad = a.diag
@check_argdims eachindex(ad) == axes(x, 1)
@check_argdims eachindex(r) == axes(x, 2)
Expand Down Expand Up @@ -186,3 +189,18 @@ function Xt_invA_X(a::PDiagMat, x::AbstractMatrix)
z = x ./ sqrt.(a.diag)
transpose(z) * z
end

### Specializations for `Array` arguments with reduced allocations

function quad(a::PDiagMat{<:Real,<:Vector}, x::Matrix)
@check_argdims a.dim == size(x, 1)
T = typeof(zero(eltype(a)) * abs2(zero(eltype(x))))
return quad!(Vector{T}(undef, size(x, 2)), a, x)
end

function invquad(a::PDiagMat{<:Real,<:Vector}, x::Matrix)
@check_argdims a.dim == size(x, 1)
T = typeof(abs2(zero(eltype(x))) / zero(eltype(a)))
return invquad!(Vector{T}(undef, size(x, 2)), a, x)
end

87 changes: 87 additions & 0 deletions src/pdmat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,78 @@ LinearAlgebra.eigmin(a::PDMat) = eigmin(a.mat)
Base.kron(A::PDMat, B::PDMat) = PDMat(kron(A.mat, B.mat), Cholesky(kron(A.chol.U, B.chol.U), 'U', A.chol.info))
LinearAlgebra.sqrt(A::PDMat) = PDMat(sqrt(Hermitian(A.mat)))

### (un)whitening

function whiten!(r::AbstractVecOrMat, a::PDMat, x::AbstractVecOrMat)
@check_argdims axes(r) == axes(x)
@check_argdims a.dim == size(x, 1)
v = _rcopy!(r, x)
return ldiv!(chol_lower(cholesky(a)), v)
end
function unwhiten!(r::AbstractVecOrMat, a::PDMat, x::AbstractVecOrMat)
@check_argdims axes(r) == axes(x)
@check_argdims a.dim == size(x, 1)
v = _rcopy!(r, x)
return lmul!(chol_lower(cholesky(a)), v)
end

function whiten(a::PDMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
return chol_lower(cholesky(a)) \ x
end
function unwhiten(a::PDMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
return chol_lower(cholesky(a)) * x
end

## quad/invquad

function quad(a::PDMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
aU_x = chol_upper(cholesky(a)) * x
if x isa AbstractVector
return sum(abs2, aU_x)
else
return vec(sum(abs2, aU_x; dims = 1))
end
end

function quad!(r::AbstractArray, a::PDMat, x::AbstractMatrix)
@check_argdims axes(r) == axes(x, 2)
@check_argdims a.dim == size(x, 1)
aU = chol_upper(cholesky(a))
z = similar(r, a.dim) # buffer to save allocations
@inbounds for i in axes(x, 2)
copyto!(z, view(x, :, i))
lmul!(aU, z)
r[i] = sum(abs2, z)
end
return r
end

function invquad(a::PDMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
inv_aL_x = chol_lower(cholesky(a)) \ x
if x isa AbstractVector
return sum(abs2, inv_aL_x)
else
return vec(sum(abs2, inv_aL_x; dims = 1))
end
end

function invquad!(r::AbstractArray, a::PDMat, x::AbstractMatrix)
@check_argdims axes(r) == axes(x, 2)
@check_argdims a.dim == size(x, 1)
aL = chol_lower(cholesky(a))
z = similar(r, a.dim) # buffer to save allocations
@inbounds for i in axes(x, 2)
copyto!(z, view(x, :, i))
ldiv!(aL, z)
r[i] = sum(abs2, z)
end
return r
end

### tri products

function X_A_Xt(a::PDMat, x::AbstractMatrix)
Expand All @@ -111,3 +183,18 @@ function Xt_invA_X(a::PDMat, x::AbstractMatrix)
z = chol_lower(a.chol) \ x
return transpose(z) * z
end

### Specializations for `Array` arguments with reduced allocations

function quad(a::PDMat{<:Real,<:Vector}, x::Matrix)
@check_argdims a.dim == size(x, 1)
T = typeof(zero(eltype(a)) * abs2(zero(eltype(x))))
return quad!(Vector{T}(undef, size(x, 2)), a, x)
end

function invquad(a::PDMat{<:Real,<:Vector}, x::Matrix)
@check_argdims a.dim == size(x, 1)
T = typeof(abs2(zero(eltype(x))) / zero(eltype(a)))
return invquad!(Vector{T}(undef, size(x, 2)), a, x)
end

65 changes: 56 additions & 9 deletions src/pdsparsemat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,37 +78,84 @@ LinearAlgebra.sqrt(A::PDSparseMat) = PDMat(sqrt(Hermitian(Matrix(A))))
### whiten and unwhiten

function whiten!(r::AbstractVecOrMat, a::PDSparseMat, x::AbstractVecOrMat)
@check_argdims axes(r) == axes(x)
@check_argdims a.dim == size(x, 1)
# Can't use `ldiv!` due to missing support in SparseArrays
return copyto!(r, chol_lower(a.chol) \ x)
end

function unwhiten!(r::AbstractVecOrMat, a::PDSparseMat, x::AbstractVecOrMat)
@check_argdims axes(r) == axes(x)
@check_argdims a.dim == size(x, 1)
# `*` is not defined for `PtL` factor components,
# so we can't use `chol_lower(a.chol) * x`
C = a.chol
PtL = sparse(C.L)[C.p, :]
# Can't use `lmul!` due to missing support in SparseArrays
return copyto!(r, PtL * x)
end

function whiten(a::PDSparseMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
return chol_lower(cholesky(a)) \ x
end

function unwhiten(a::PDSparseMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
# `*` is not defined for `PtL` factor components,
# so we can't use `chol_lower(a.chol) * x`
C = a.chol
PtL = sparse(C.L)[C.p, :]
return PtL * x
end

### quadratic forms

quad(a::PDSparseMat, x::AbstractVector) = dot(x, a * x)
invquad(a::PDSparseMat, x::AbstractVector) = dot(x, a \ x)
function quad(a::PDSparseMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
# https://github.com/JuliaLang/julia/commit/2425ae760fb5151c5c7dd0554e87c5fc9e24de73
if VERSION < v"1.4.0-DEV.92"
z = a.mat * x
return x isa AbstractVector ? dot(x, z) : map(dot, eachcol(x), eachcol(z))
else
return x isa AbstractVector ? dot(x, a.mat, x) : map(Base.Fix1(quad, a), eachcol(x))
end
end

function quad!(r::AbstractArray, a::PDSparseMat, x::AbstractMatrix)
@check_argdims eachindex(r) == axes(x, 2)
for i in axes(x, 2)
r[i] = quad(a, x[:,i])
@check_argdims axes(r) == axes(x, 2)
# https://github.com/JuliaLang/julia/commit/2425ae760fb5151c5c7dd0554e87c5fc9e24de73
if VERSION < v"1.4.0-DEV.92"
z = similar(r, a.dim) # buffer to save allocations
@inbounds for i in axes(x, 2)
xi = view(x, :, i)
copyto!(z, xi)
lmul!(a.mat, z)
r[i] = dot(xi, z)
end
else
@inbounds for i in axes(x, 2)
xi = view(x, :, i)
r[i] = dot(xi, a.mat, xi)
end
end
return r
end

function invquad(a::PDSparseMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
z = a.chol \ x
return x isa AbstractVector ? dot(x, z) : map(dot, eachcol(x), eachcol(z))
end

function invquad!(r::AbstractArray, a::PDSparseMat, x::AbstractMatrix)
@check_argdims eachindex(r) == axes(x, 2)
for i in axes(x, 2)
r[i] = invquad(a, x[:,i])
@check_argdims axes(r) == axes(x, 2)
@check_argdims a.dim == size(x, 1)
z = similar(r, a.dim) # buffer to save allocations
@inbounds for i in axes(x, 2)
xi = view(x, :, i)
copyto!(z, xi)
ldiv!(a.chol, z)
r[i] = dot(xi, z)
end
return r
end
Expand Down
Loading

0 comments on commit 0c0f04c

Please sign in to comment.