Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support (un)whiten and (inv)quad with static arrays #183

Merged
merged 9 commits into from
Oct 13, 2023
54 changes: 24 additions & 30 deletions src/generics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,6 @@ LinearAlgebra.checksquare(a::AbstractPDMat) = size(a, 1)

## whiten and unwhiten

whiten!(a::AbstractMatrix, x::AbstractVecOrMat) = whiten!(x, a, x)
unwhiten!(a::AbstractMatrix, x::AbstractVecOrMat) = unwhiten!(x, a, x)

function whiten!(r::AbstractVecOrMat, a::AbstractMatrix, x::AbstractVecOrMat)
v = _rcopy!(r, x)
ldiv!(chol_lower(cholesky(a)), v)
end

function unwhiten!(r::AbstractVecOrMat, a::AbstractMatrix, x::AbstractVecOrMat)
v = _rcopy!(r, x)
lmul!(chol_lower(cholesky(a)), v)
end

"""
whiten(a::AbstractMatrix, x::AbstractVecOrMat)
unwhiten(a::AbstractMatrix, x::AbstractVecOrMat)
Expand Down Expand Up @@ -80,35 +67,41 @@ julia> W * W'
0.0 1.0
```
"""
whiten(a::AbstractMatrix, x::AbstractVecOrMat) = whiten!(similar(x), a, x)
unwhiten(a::AbstractMatrix, x::AbstractVecOrMat) = unwhiten!(similar(x), a, x)
whiten(a::AbstractMatrix{<:Real}, x::AbstractVecOrMat) = whiten(AbstractPDMat(a), x)
unwhiten(a::AbstractMatrix{<:Real}, x::AbstractVecOrMat) = unwhiten(AbstractPDMat(a), x)
Comment on lines +70 to +71
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the restriction on the type parameter?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, AbstractPDMats are Real matrices:

abstract type AbstractPDMat{T<:Real} <: AbstractMatrix{T} end


whiten!(a::AbstractMatrix{<:Real}, x::AbstractVecOrMat) = whiten!(x, a, x)
unwhiten!(a::AbstractMatrix{<:Real}, x::AbstractVecOrMat) = unwhiten!(x, a, x)

function whiten!(r::AbstractVecOrMat, a::AbstractMatrix{<:Real}, x::AbstractVecOrMat)
return whiten!(r, AbstractPDMat(a), x)
end
function unwhiten!(r::AbstractVecOrMat, a::AbstractMatrix{<:Real}, x::AbstractVecOrMat)
return unwhiten!(r, AbstractPDMat(a), x)
end

## quad

"""
quad(a::AbstractMatrix, x::AbstractVecOrMat)

Return the value of the quadratic form defined by `a` applied to `x`
Return the value of the quadratic form defined by `a` applied to `x`.

If `x` is a vector the quadratic form is `x' * a * x`. If `x` is a matrix
the quadratic form is applied column-wise.
"""
function quad(a::AbstractMatrix{T}, x::AbstractMatrix{S}) where {T<:Real, S<:Real}
@check_argdims LinearAlgebra.checksquare(a) == size(x, 1)
quad!(Array{promote_type(T, S)}(undef, size(x,2)), a, x)
function quad(a::AbstractMatrix{<:Real}, x::AbstractVecOrMat)
return quad(AbstractPDMat(a), x)
end

quad(a::AbstractMatrix, x::AbstractVector) = sum(abs2, chol_upper(cholesky(a)) * x)
invquad(a::AbstractMatrix, x::AbstractVector) = sum(abs2, chol_lower(cholesky(a)) \ x)

"""
quad!(r::AbstractArray, a::AbstractMatrix, x::AbstractMatrix)

Overwrite `r` with the value of the quadratic form defined by `a` applied columnwise to `x`
Overwrite `r` with the value of the quadratic form defined by `a` applied columnwise to `x`.
"""
quad!(r::AbstractArray, a::AbstractMatrix, x::AbstractMatrix) = colwise_dot!(r, x, a * x)

function quad!(r::AbstractArray, a::AbstractMatrix{<:Real}, x::AbstractMatrix)
return quad!(r, AbstractPDMat(a), x)
end

"""
invquad(a::AbstractMatrix, x::AbstractVecOrMat)
Expand All @@ -120,15 +113,16 @@ For most `PDMat` types this is done in a way that does not require evaluation of
If `x` is a vector the quadratic form is `x' * a * x`. If `x` is a matrix
the quadratic form is applied column-wise.
"""
invquad(a::AbstractMatrix, x::AbstractVecOrMat) = x' / a * x
function invquad(a::AbstractMatrix{T}, x::AbstractMatrix{S}) where {T<:Real, S<:Real}
@check_argdims LinearAlgebra.checksquare(a) == size(x, 1)
invquad!(Array{promote_type(T, S)}(undef, size(x,2)), a, x)
function invquad(a::AbstractMatrix{<:Real}, x::AbstractVecOrMat)
return invquad(AbstractPDMat(a), x)
end

"""
invquad!(r::AbstractArray, a::AbstractMatrix, x::AbstractMatrix)

Overwrite `r` with the value of the quadratic form defined by `inv(a)` applied columnwise to `x`
"""
invquad!(r::AbstractArray, a::AbstractMatrix, x::AbstractMatrix) = colwise_dot!(r, x, a \ x)
function invquad!(r::AbstractArray, a::AbstractMatrix{<:Real}, x::AbstractMatrix)
return invquad!(r, AbstractPDMat(a), x)
end

82 changes: 50 additions & 32 deletions src/pdiagmat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,45 +91,38 @@ LinearAlgebra.sqrt(a::PDiagMat) = PDiagMat(map(sqrt, a.diag))

### whiten and unwhiten

function whiten!(r::StridedVector, a::PDiagMat, x::StridedVector)
n = a.dim
@check_argdims length(r) == length(x) == n
v = a.diag
for i = 1:n
r[i] = x[i] / sqrt(v[i])
end
return r
function whiten!(r::AbstractVecOrMat, a::PDiagMat, x::AbstractVecOrMat)
@check_argdims axes(r) == axes(x)
@check_argdims a.dim == size(x, 1)
return r .= x ./ sqrt.(a.diag)
end

function unwhiten!(r::StridedVector, a::PDiagMat, x::StridedVector)
n = a.dim
@check_argdims length(r) == length(x) == n
v = a.diag
for i = 1:n
r[i] = x[i] * sqrt(v[i])
end
return r
function unwhiten!(r::AbstractVecOrMat, a::PDiagMat, x::AbstractVecOrMat)
@check_argdims axes(r) == axes(x)
@check_argdims a.dim == size(x, 1)
return r .= x .* sqrt.(a.diag)
end

function whiten!(r::StridedMatrix, a::PDiagMat, x::StridedMatrix)
r .= x ./ sqrt.(a.diag)
return r
function whiten(a::PDiagMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
return x ./ sqrt.(a.diag)
end

function unwhiten!(r::StridedMatrix, a::PDiagMat, x::StridedMatrix)
r .= x .* sqrt.(a.diag)
return r
function unwhiten(a::PDiagMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
return x .* sqrt.(a.diag)
end


whiten!(r::AbstractVecOrMat, a::PDiagMat, x::AbstractVecOrMat) = r .= x ./ sqrt.(a.diag)
unwhiten!(r::AbstractVecOrMat, a::PDiagMat, x::AbstractVecOrMat) = r .= x .* sqrt.(a.diag)


### quadratic forms

quad(a::PDiagMat, x::AbstractVector) = wsumsq(a.diag, x)
invquad(a::PDiagMat, x::AbstractVector) = invwsumsq(a.diag, x)
function quad(a::PDiagMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
if x isa AbstractVector
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if it is helpful/desirable here, but generally I tried to reduce the number of methods to lower the probability for method ambiguity errors. Maybe it's mostly useful for the generic code path in src/generics.jl.

return wsumsq(a.diag, x)
else
# map(Base.Fix1(invquad, a), eachcol(x)) or similar alternatives
# do NOT return a `SVector` for inputs `x::SMatrix`.
return vec(sum(abs2.(x) .* a.diag; dims = 1))
Comment on lines +121 to +123
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit unsatisfying - is there any way we could avoid unnecessary allocations but still make StaticArray return the expected types?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a version with reduced allocations specialized for Array arguments.

end
end

function quad!(r::AbstractArray, a::PDiagMat, x::AbstractMatrix)
ad = a.diag
Expand All @@ -145,8 +138,18 @@ function quad!(r::AbstractArray, a::PDiagMat, x::AbstractMatrix)
r
end

function invquad(a::PDiagMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
if x isa AbstractVector
return invwsumsq(a.diag, x)
else
# map(Base.Fix1(invquad, a), eachcol(x)) or similar alternatives
# do NOT return a `SVector` for inputs `x::SMatrix`.
return vec(sum(abs2.(x) ./ a.diag; dims = 1))
end
end

function invquad!(r::AbstractArray, a::PDiagMat, x::AbstractMatrix)
m, n = size(x)
ad = a.diag
@check_argdims eachindex(ad) == axes(x, 1)
@check_argdims eachindex(r) == axes(x, 2)
Expand Down Expand Up @@ -186,3 +189,18 @@ function Xt_invA_X(a::PDiagMat, x::AbstractMatrix)
z = x ./ sqrt.(a.diag)
transpose(z) * z
end

### Specializations for `Array` arguments with reduced allocations

function quad(a::PDiagMat{<:Real,<:Vector}, x::Matrix)
@check_argdims a.dim == size(x, 1)
T = typeof(zero(eltype(a)) * abs2(zero(eltype(x))))
return quad!(Vector{T}(undef, size(x, 2)), a, x)
end

function invquad(a::PDiagMat{<:Real,<:Vector}, x::Matrix)
@check_argdims a.dim == size(x, 1)
T = typeof(abs2(zero(eltype(x))) / zero(eltype(a)))
return invquad!(Vector{T}(undef, size(x, 2)), a, x)
end

87 changes: 87 additions & 0 deletions src/pdmat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,78 @@ LinearAlgebra.eigmin(a::PDMat) = eigmin(a.mat)
Base.kron(A::PDMat, B::PDMat) = PDMat(kron(A.mat, B.mat), Cholesky(kron(A.chol.U, B.chol.U), 'U', A.chol.info))
LinearAlgebra.sqrt(A::PDMat) = PDMat(sqrt(Hermitian(A.mat)))

### (un)whitening

function whiten!(r::AbstractVecOrMat, a::PDMat, x::AbstractVecOrMat)
@check_argdims axes(r) == axes(x)
@check_argdims a.dim == size(x, 1)
v = _rcopy!(r, x)
return ldiv!(chol_lower(cholesky(a)), v)
end
function unwhiten!(r::AbstractVecOrMat, a::PDMat, x::AbstractVecOrMat)
@check_argdims axes(r) == axes(x)
@check_argdims a.dim == size(x, 1)
v = _rcopy!(r, x)
return lmul!(chol_lower(cholesky(a)), v)
end

function whiten(a::PDMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
return chol_lower(cholesky(a)) \ x
end
function unwhiten(a::PDMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
return chol_lower(cholesky(a)) * x
end

## quad/invquad

function quad(a::PDMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
aU_x = chol_upper(cholesky(a)) * x
if x isa AbstractVector
return sum(abs2, aU_x)
else
return vec(sum(abs2, aU_x; dims = 1))
end
end

function quad!(r::AbstractArray, a::PDMat, x::AbstractMatrix)
@check_argdims axes(r) == axes(x, 2)
@check_argdims a.dim == size(x, 1)
aU = chol_upper(cholesky(a))
z = similar(r, a.dim) # buffer to save allocations
@inbounds for i in axes(x, 2)
copyto!(z, view(x, :, i))
lmul!(aU, z)
r[i] = sum(abs2, z)
end
return r
end

function invquad(a::PDMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
inv_aL_x = chol_lower(cholesky(a)) \ x
if x isa AbstractVector
return sum(abs2, inv_aL_x)
else
return vec(sum(abs2, inv_aL_x; dims = 1))
end
end

function invquad!(r::AbstractArray, a::PDMat, x::AbstractMatrix)
@check_argdims axes(r) == axes(x, 2)
@check_argdims a.dim == size(x, 1)
aL = chol_lower(cholesky(a))
z = similar(r, a.dim) # buffer to save allocations
@inbounds for i in axes(x, 2)
copyto!(z, view(x, :, i))
ldiv!(aL, z)
r[i] = sum(abs2, z)
end
return r
end

### tri products

function X_A_Xt(a::PDMat, x::AbstractMatrix)
Expand All @@ -111,3 +183,18 @@ function Xt_invA_X(a::PDMat, x::AbstractMatrix)
z = chol_lower(a.chol) \ x
return transpose(z) * z
end

### Specializations for `Array` arguments with reduced allocations

function quad(a::PDMat{<:Real,<:Vector}, x::Matrix)
@check_argdims a.dim == size(x, 1)
T = typeof(zero(eltype(a)) * abs2(zero(eltype(x))))
return quad!(Vector{T}(undef, size(x, 2)), a, x)
end

function invquad(a::PDMat{<:Real,<:Vector}, x::Matrix)
@check_argdims a.dim == size(x, 1)
T = typeof(abs2(zero(eltype(x))) / zero(eltype(a)))
return invquad!(Vector{T}(undef, size(x, 2)), a, x)
end

65 changes: 56 additions & 9 deletions src/pdsparsemat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,37 +78,84 @@ LinearAlgebra.sqrt(A::PDSparseMat) = PDMat(sqrt(Hermitian(Matrix(A))))
### whiten and unwhiten

function whiten!(r::AbstractVecOrMat, a::PDSparseMat, x::AbstractVecOrMat)
@check_argdims axes(r) == axes(x)
@check_argdims a.dim == size(x, 1)
# Can't use `ldiv!` due to missing support in SparseArrays
return copyto!(r, chol_lower(a.chol) \ x)
end

function unwhiten!(r::AbstractVecOrMat, a::PDSparseMat, x::AbstractVecOrMat)
@check_argdims axes(r) == axes(x)
@check_argdims a.dim == size(x, 1)
# `*` is not defined for `PtL` factor components,
# so we can't use `chol_lower(a.chol) * x`
C = a.chol
PtL = sparse(C.L)[C.p, :]
# Can't use `lmul!` due to missing support in SparseArrays
return copyto!(r, PtL * x)
end

function whiten(a::PDSparseMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
return chol_lower(cholesky(a)) \ x
end

function unwhiten(a::PDSparseMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
# `*` is not defined for `PtL` factor components,
# so we can't use `chol_lower(a.chol) * x`
C = a.chol
PtL = sparse(C.L)[C.p, :]
return PtL * x
end

### quadratic forms

quad(a::PDSparseMat, x::AbstractVector) = dot(x, a * x)
invquad(a::PDSparseMat, x::AbstractVector) = dot(x, a \ x)
function quad(a::PDSparseMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
# https://github.com/JuliaLang/julia/commit/2425ae760fb5151c5c7dd0554e87c5fc9e24de73
if VERSION < v"1.4.0-DEV.92"
z = a.mat * x
return x isa AbstractVector ? dot(x, z) : map(dot, eachcol(x), eachcol(z))
else
return x isa AbstractVector ? dot(x, a.mat, x) : map(Base.Fix1(quad, a), eachcol(x))
end
end

function quad!(r::AbstractArray, a::PDSparseMat, x::AbstractMatrix)
@check_argdims eachindex(r) == axes(x, 2)
for i in axes(x, 2)
r[i] = quad(a, x[:,i])
@check_argdims axes(r) == axes(x, 2)
# https://github.com/JuliaLang/julia/commit/2425ae760fb5151c5c7dd0554e87c5fc9e24de73
if VERSION < v"1.4.0-DEV.92"
z = similar(r, a.dim) # buffer to save allocations
@inbounds for i in axes(x, 2)
xi = view(x, :, i)
copyto!(z, xi)
lmul!(a.mat, z)
r[i] = dot(xi, z)
end
else
@inbounds for i in axes(x, 2)
xi = view(x, :, i)
r[i] = dot(xi, a.mat, xi)
end
end
return r
end

function invquad(a::PDSparseMat, x::AbstractVecOrMat)
@check_argdims a.dim == size(x, 1)
z = a.chol \ x
return x isa AbstractVector ? dot(x, z) : map(dot, eachcol(x), eachcol(z))
end

function invquad!(r::AbstractArray, a::PDSparseMat, x::AbstractMatrix)
@check_argdims eachindex(r) == axes(x, 2)
for i in axes(x, 2)
r[i] = invquad(a, x[:,i])
@check_argdims axes(r) == axes(x, 2)
@check_argdims a.dim == size(x, 1)
z = similar(r, a.dim) # buffer to save allocations
@inbounds for i in axes(x, 2)
xi = view(x, :, i)
copyto!(z, xi)
ldiv!(a.chol, z)
r[i] = dot(xi, z)
end
return r
end
Expand Down
Loading
Loading