Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MomentMatrixWeightSolver option #27

Merged
merged 4 commits into from
Nov 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ julia = "1"

[extras]
DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["DynamicPolynomials", "Test"]
test = ["DynamicPolynomials", "Random", "Test"]
10 changes: 10 additions & 0 deletions docs/src/atoms.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,13 @@ ShiftChol
SVDChol
MultivariateMoments.lowrankchol
```

Once the center of the atoms are determined, a linear system is solved to determine
the weights corresponding to each dirac.
By default, [`MomentMatrixWeightSolver`](@ref) is used by [`extractatoms`](@ref) so that if there are small differences between moment values corresponding to the same monomial in the matrix
(which can happen if these moments were computed numerically by a semidefinite proramming solvers, e.g., with [SumOfSquares](https://github.com/jump-dev/SumOfSquares.jl)),
the linear system handles that automatically.
```@docs
MomentMatrixWeightSolver
MomentVectorWeightSolver
```
88 changes: 78 additions & 10 deletions src/extract.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
export extractatoms
export LowRankChol, ShiftChol, SVDChol
export MomentMatrixWeightSolver, MomentVectorWeightSolver

using RowEchelon
using SemialgebraicSets
Expand Down Expand Up @@ -166,6 +167,80 @@ function computesupport!(μ::MomentMatrix, ranktol::Real, args...)
return computesupport!(μ::MomentMatrix, ranktol::Real, SVDChol(), args...)
end

# Determines weight

"""
struct MomentMatrixWeightSolver
rtol::T
atol::T
end

Given a moment matrix `ν` and the atom centers,
determine the weights by solving a linear system over all the moments
of the moment matrix, keeping duplicates (e.g., entries corresponding to the same monomial).

If the moment values corresponding to the same monomials are known to be equal
prefer [`MomentVectorWeightSolver`](@ref) instead.
"""
struct MomentMatrixWeightSolver
end

function solve_weight(ν::MomentMatrix{T}, centers, solver::MomentMatrixWeightSolver) where {T}
vars = variables(ν)
A = Matrix{T}(undef, length(ν.Q.Q), length(centers))
vbasis = vectorized_basis(ν)
for i in eachindex(centers)
η = dirac(vbasis.monomials, vars => centers[i])
A[:, i] = moment_matrix(η, ν.basis.monomials).Q.Q
end
return A \ ν.Q.Q
end

"""
struct MomentVectorWeightSolver{T}
rtol::T
atol::T
end

Given a moment matrix `ν` and the atom centers, first convert the moment matrix
to a vector of moments, using [`measure(ν; rtol=rtol, atol=atol)`](@ref measure)
and then determine the weights by solving a linear system over the monomials obtained.

If the moment values corresponding to the same monomials can have small differences,
[`measure`](@ref) can throw an error if `rtol` and `atol` are not small enough.
Alternatively to tuning these tolerances [`MomentVectorWeightSolver`](@ref) can be used instead.
"""
struct MomentVectorWeightSolver{T}
rtol::T
atol::T
end
function MomentVectorWeightSolver{T}(; rtol=Base.rtoldefault(T), atol=zero(T)) where {T}
return MomentVectorWeightSolver{T}(rtol, atol)
end
function MomentVectorWeightSolver(; rtol=nothing, atol=nothing)
if rtol === nothing && atol === nothing
return MomentVectorWeightSolver{Float64}()
elseif rtol !== nothing
if atol === nothing
return MomentVectorWeightSolver{typeof(rtol)}(; rtol=rtol)
else
return MomentVectorWeightSolver{typeof(rtol)}(; rtol=rtol, atol=atol)
end
else
return MomentVectorWeightSolver{typeof(atol)}(; atol=atol)
end
end

function solve_weight(ν::MomentMatrix{T}, centers, solver::MomentVectorWeightSolver) where {T}
μ = measure(ν; rtol=solver.rtol, atol=solver.atol)
vars = variables(μ)
A = Matrix{T}(undef, length(μ.x), length(centers))
for i in eachindex(centers)
A[:, i] = dirac(μ.x, vars => centers[i]).a
end
return A \ μ.a
end

"""
extractatoms(ν::MomentMatrix, ranktol, [dec::LowRankChol], [solver::SemialgebraicSets.AbstractAlgebraicSolver])

Expand All @@ -188,28 +263,21 @@ then the Schur decomposition of a random combination of these matrices.
For floating point arithmetics, homotopy continuation is recommended as it is
more numerically stable than Gröbner basis computation.
"""
function extractatoms(ν::MomentMatrix{T}, ranktol, args...) where T
function extractatoms(ν::MomentMatrix{T}, ranktol, args...; weight_solver = MomentMatrixWeightSolver()) where T
computesupport!(ν, ranktol, args...)
supp = ν.support
if !iszerodimensional(supp)
return nothing
end
centers = collect(supp)
r = length(centers)
# Determine weights
μ = measure(ν)
vars = variables(μ)
A = Matrix{T}(undef, length(μ.x), r)
for i in 1:r
A[:, i] = dirac(μ.x, vars => centers[i]).a
end
weights = A \ μ.a
weights = solve_weight(ν, centers, weight_solver)
isf = isfinite.(weights)
weights = weights[isf]
centers = centers[isf]
if isempty(centers)
nothing
else
AtomicMeasure(vars, WeightedDiracMeasure.(centers, weights))
AtomicMeasure(variables(ν), WeightedDiracMeasure.(centers, weights))
end
end
11 changes: 7 additions & 4 deletions src/measure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ struct Measure{T, MT <: AbstractMonomial, MVT <: AbstractVector{MT}} <: Abstract
new(a, x)
end
end
function Measure(a::AbstractVector{T}, x::AbstractVector{TT}) where {T, TT <: AbstractTermLike}
function Measure(a::AbstractVector{T}, x::AbstractVector{TT}; kws...) where {T, TT <: AbstractTermLike}
# cannot use `monovec(a, x)` as it would sum the entries
# corresponding to the same monomial.
if length(a) != length(x)
Expand All @@ -24,7 +24,7 @@ function Measure(a::AbstractVector{T}, x::AbstractVector{TT}) where {T, TT <: Ab
for i in eachindex(x)
j = rev[x[i]]
if i != σ[j]
if !isapprox(b[j], a[i])
if !isapprox(b[j], a[i]; kws...)
error("The monomial `$(x[i])` is occurs twice with different values: `$(a[i])` and `$(b[j])`")
end
end
Expand All @@ -34,11 +34,14 @@ function Measure(a::AbstractVector{T}, x::AbstractVector{TT}) where {T, TT <: Ab
end

"""
measure(a, X::AbstractVector{<:AbstractMonomial})
measure(a::AbstractVector{T}, X::AbstractVector{<:AbstractMonomial}; rtol=Base.rtoldefault(T), atol=zero(T))

Creates a measure with moments `moment(a[i], X[i])` for each `i`.
An error is thrown if there exists `i` and `j` such that `X[i] == X[j]` but
`!isapprox(a[i], a[j]; rtol=rtol, atol=atol)`.
"""
measure(a, X) = Measure(a, X)
measure(a, X; kws...) = Measure(a, X; kws...)
measure(a, basis::MB.MonomialBasis; kws...) = measure(a, basis.monomials; kws...)

"""
variables(μ::AbstractMeasureLike)
Expand Down
11 changes: 8 additions & 3 deletions src/moment_matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,15 @@ moment_matrix(Q::AbstractMatrix, monos) = MomentMatrix(Q, monos)

getmat(μ::MomentMatrix) = Matrix(μ.Q)

function measure(ν::MomentMatrix{T, <:MB.MonomialBasis, SymMatrix{T}}) where T
n = length(ν.basis)
function vectorized_basis(ν::MomentMatrix{T,<:MB.MonomialBasis}) where {T}
monos = ν.basis.monomials
measure(ν.Q.Q, [monos[i] * monos[j] for i in 1:n for j in 1:i])
n = length(monos)
return MB.MonomialBasis([monos[i] * monos[j] for i in 1:n for j in 1:i])
end

function measure(ν::MomentMatrix; kws...) where T
n = length(ν.basis)
measure(ν.Q.Q, vectorized_basis(ν); kws...)
end

struct SparseMomentMatrix{T, B <: MB.AbstractPolynomialBasis, MT} <: AbstractMomentMatrix{T, B}
Expand Down
13 changes: 12 additions & 1 deletion test/moment_matrix.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Test
using Random

struct DummySolver <: SemialgebraicSets.AbstractAlgebraicSolver end
function SemialgebraicSets.solvealgebraicequations(
Expand Down Expand Up @@ -52,6 +53,16 @@ const DEFAULT_SOLVER = SemialgebraicSets.defaultalgebraicsolver([1.0x - 1.0x])
atoms = extractatoms(ν, 1e-4, lrc, DEFAULT_SOLVER)
@test atoms !== nothing
@test atoms ≈ η
if !(lrc isa ShiftChol) # the shift `1e-14` is too small compared to the noise of `1e-6`. We want high noise so that the default rtol of `Base.rtoldefault` does not work so that it tests that `rtol` is passed around.
Random.seed!(0)
ν2 = MomentMatrix(SymMatrix(ν.Q.Q + rand(length(ν.Q.Q)) * 1e-6, ν.Q.n), ν.basis)
@test_throws ErrorException extractatoms(ν2, 1e-4, lrc, DEFAULT_SOLVER, weight_solver = MomentVectorWeightSolver())
for solver in [MomentMatrixWeightSolver(), MomentVectorWeightSolver(rtol=1e-5), MomentVectorWeightSolver(atol=1e-5)]
atoms = extractatoms(ν2, 1e-4, lrc, DEFAULT_SOLVER, weight_solver = solver)
@test atoms !== nothing
@test atoms ≈ η rtol=1e-4
end
end
end
end

Expand Down Expand Up @@ -144,7 +155,7 @@ end
# With 1e-6, the rank is detected to be 3
# With 1e-7, the rank is detected to be 5
# With 1e-8, the rank is detected to be 6
atoms = extractatoms(ν, ranktol)
atoms = extractatoms(ν, ranktol, weight_solver=MomentVectorWeightSolver())
@test atoms !== nothing
@test atoms ≈ η
end
Expand Down