diff --git a/Project.toml b/Project.toml index c5a75472..4edbbfd9 100644 --- a/Project.toml +++ b/Project.toml @@ -4,9 +4,11 @@ authors = ["Chad Scherrer ", "Oliver Schulz tpl[from + i - 1], Val(until - from + 1)) +end +# ToDo: Is this specialization necessary? +Base.Base.@propagate_inbounds function _get_or_view(tpl::Tuple, ::StaticInteger{from}, ::StaticInteger{until}) where {from,until} + ntuple(i -> tpl[from + i - 1], Val(until - from + 1)) +end + + +@inline function _split_after(x::AbstractVector, n::IntegerLike) + idxs = maybestatic_eachindex(x) + i_first = maybestatic_first(idxs) + i_last = maybestatic_last(idxs) + _get_or_view(x, i_first, i_first + n - one(n)), _get_or_view(x, i_first + n, i_last) +end + +@inline _split_after(x::Tuple, n) = _split_after(x::Tuple, Val{n}()) +@inline _split_after(x::Tuple, ::Val{N}) where N = x[begin:begin+N-1], x[begin+N:end] + +@generated function _split_after(x::NamedTuple{names}, ::Val{names_a}) where {names, names_a} + n = length(names_a) + if names[begin:begin+n-1] == names_a + names_b = names[begin+n:end] + quote + a, b = _split_after(values(x), Val($n)) + NamedTuple{$names_a}(a), NamedTuple{$names_b}(b) + end + else + quote + throw(ArgumentError("Can't split NamedTuple{$names} after {$names_a}")) + end + end +end + + +_empty_zero(::AbstractVector{T}) where {T<:Real} = Fill(zero(T), 0) + + +#= +struct _TupleNamer{names} <: Function end +struct _TupleUnNamer{names} <: Function end + +(::TupleNamer{names})(x::Tuple) where names = NamedTuple{names}(x) +InverseFunctions.inverse(::TupleNamer{names}) where names = TupleUnNamer{names}() +ChangesOfVariables.with_logabsdet_jacobian(::TupleNamer{names}, x::Tuple) where names = static(false) + +(::TupleUnNamer{names})(x::NamedTuple{names}) where {names} = values(x) +InverseFunctions.inverse(::TupleUnNamer{names}) where names = TupleNamer{names}() +ChangesOfVariables.with_logabsdet_jacobian(::TupleUnNamer{names}, x::NamedTuple{names}) where names = static(false) +=# + +# Field access functions for Fill: +_fill_value(x::FillArrays.Fill) = x.value +_fill_axes(x::FillArrays.Fill) = x.axes + + +_flatten_to_rv(VV::AbstractVector{<:AbstractVector{<:Real}}) = flatview(VectorOfArrays(VV)) +_flatten_to_rv(VV::AbstractVector{<:StaticVector{N,<:Real}}) where N = flatview(VectorOfSimilarArrays(VV)) + +_flatten_to_rv(VV::VectorOfSimilarVectors{<:Real}) = flatview(VV) +_flatten_to_rv(VV::VectorOfVectors{<:Real}) = flatview(VV) + +_flatten_to_rv(::Tuple{}) = [] +_flatten_to_rv(tpl::Tuple{Vararg{AbstractVector}}) = vcat(tpl...) +_flatten_to_rv(tpl::Tuple{Vararg{StaticVector}}) = vcat(tpl...) diff --git a/src/combinators/abstract_product.jl b/src/combinators/abstract_product.jl new file mode 100644 index 00000000..2d14cbce --- /dev/null +++ b/src/combinators/abstract_product.jl @@ -0,0 +1,42 @@ +""" + marginals(μ::AbstractMeasure) + +Returns the marginals measures of `μ` as a collection of measures. + +The kind of marginalization implied by `marginals` depends on the +type of `μ`. + +`μ` may be a power of a measure or a product of measures, but other +types of measures may support `marginals` as well. +""" +function marginals end +export marginals + + +""" + abstract type AbstractProductMeasure + +Abstact type for products of measures. + +[`marginals(μ::AbstractProductMeasure)`](@ref) returns the collection of +measures that `μ` is the product of. 
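For example (a sketch, using the standard measures exported by this package):

```julia
μ = productmeasure((StdNormal(), StdExponential()))
marginals(μ)  # the tuple (StdNormal(), StdExponential())
```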
+""" +abstract type AbstractProductMeasure <: AbstractMeasure end +export AbstractProductMeasure + +function Pretty.tile(μ::AbstractProductMeasure) + result = Pretty.literal("ProductMeasure(") + result *= Pretty.tile(marginals(μ)) + result *= Pretty.literal(")") +end + +massof(m::AbstractProductMeasure) = prod(massof, marginals(m)) + +Base.:(==)(a::AbstractProductMeasure, b::AbstractProductMeasure) = marginals(a) == marginals(b) +Base.isapprox(a::AbstractProductMeasure, b::AbstractProductMeasure; kwargs...) = isapprox(marginals(a), marginals(b); kwargs...) + + +# # ToDo: Do we want this? It's not so clear what the semantics of `length` and `size` +# # for measures should be, in general: +# Base.length(μ::AbstractProductMeasure) = length(marginals(μ)) +# Base.size(μ::AbstractProductMeasure) = size(marginals(μ)) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index cc2022f2..fe6af2bf 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -1,36 +1,332 @@ -struct Bind{M,K} <: AbstractMeasure - μ::M - k::K -end +@doc raw""" + mkernel(f_β, f_c = OneTwoMany.secondarg)::Function + +Constructs generalized monadic transistion kernel from a primary transition +kernel function `f_β` and a value combination function `f_c`. + +`f_β` must behave like `β = f_β(a)`, taking a value `a` from a primary +measurable space and return a measure-like object `β`. + +`f_c` must behave like `c = f_c(a, b)`, taking a value `a` (like f_β) and a +value `b` from the measurable space of `β` and return a value `c`. + +`f_k = mkernel(f_β, f_c)` then acts like + +```julia +f_k(a) ≡ pushforward(c -> f_c(c[1], c[2]), productmeasure((Dirac(a), f_β(a)))) +``` + +(`≡` denoting pseudocode-equivalency here). So with the default +`f_c == OneTwoMany.secondarg`, we just have `f_k(a) ≡ f_β(a) + +Also, + +```julia +mbind(mkernel(f_β, f_c), α) == mbind(f_β, α, f_c) +``` + +See also [`mbind`](@ref). +""" +function mkernel end +export mkernel + + +""" + struct MeasureBase.MKernel <: Function -export ↣ +Represents a generalized monatic transition kernel. +User code should not create instances of `MKernel` directly, but should call +[`mkernel`](@ref) instead. """ -If -- μ is an `AbstractMeasure` or satisfies the Measure interface, and -- k is a function taking values from the support of μ and returning a measure +struct MKernel{FT,FC} <: Function + f_β::FT + f_c::FC +end + + +@inline mkernel(f_β::MKernel) = f_β +@inline mkernel(f_β, f_c = secondarg) = _generic_mkernel_impl(f_β, f_c) + +@inline _generic_mkernel_impl(f_β, f_c) = MKernel(f_β, f_c) +@inline _generic_mkernel_impl(f_β::MKernel, ::typeof(secondarg)) = f_β + + + +@doc raw""" + mbind(f_β, α::AbstractMeasure, f_c = OneTwoMany.secondarg) + mbind(f_β::MeasureBase.MKernel, α::AbstractMeasure) -Then `μ ↣ k` is a measure, called a *monadic bind*. In a -probabilistic programming language like Soss.jl, this could be expressed as +Constructs a monadic bind, resp. a hierarchical measure, from a transition +kernel function `f_β`, a primary measure `α` and a value combination +function `f_c`. + +`f_β` must be a function that maps a point `a` from the space of the primary +measure `α` to a dependent secondary measure `β_a = f_β(a)`. +`ab = f_c(a, b)` must map such a point `a` and a point `b` from the +space of measure `β_a` to a combined value `ab = f_c(a, b)`. 
+ +The resulting measure + +```julia +μ = mbind(f_β, α, f_c) +``` + +has the mathethematical interpretation (on sets $$A$$ and $$B$$) + +```math +\mu(f_c(A, B)) = \int_A \beta_a(B)\, \mathrm{d}\, \alpha(a) +``` -Note that bind is usually written `>>=`, but this symbol is unavailable in Julia. +When using the default `fc = OneTwoMany.secondarg` (so `ab == b`) this simplies to +```math +\mu(B) = \int_A \beta_a(B)\, \mathrm{d}\, \alpha(a) ``` -bind = @model μ,k begin - x ~ μ - y ~ k(x) - return y + +which is equivalent to a monadic bind, viewing measures as monads. + +Computationally, `ab = rand(μ)` is equivalent to + +```julia +a = rand(μ_primary) +β_a = f_β(a) +b = rand(β_a) +ab = f_c(a, b) +``` + +The measure `α` that went into the bind can be retrieved via +`boundmeasure(mbind(f_β, α, ...)) == α`. + +`mbind(f_β, α, f_c)` is equivalent to `mbind(mkernel(f_β, f_c), α)` +(see [`mkernel`](@ref)) with +`bindkernel(mbind(mkernel(f_β, f_c), α)) == mbind(mkernel(f_β, f_c)`. + +Densities on hierarchical measures can only be evaluated if `ab = f_c(a, b)` +can be unambiguously split into `a` and `b` again, knowing `α`. This is +currently implemented for `f_c` that is either `tuple` or `=>`/`Pair` (these +work for any combination of variate types), `vcat` (for tuple- or vector-like +variates) and `merge` (`NamedTuple` variates). +[`MeasureBase.split_point(::typeof(f_c), α)`](@ref) can be specialized to +support other choices for `f_c`. + +# Extended help + +Bayesian example with a correlated prior: Mathematically, let + +position = a1 ~ StdNormal(), +noise = a2 ~ pushforward(h(a1, .), StdExponential()) + +where `h(a1,a2) = √(abs(a1) * a2)`. +Because this prior on the space of `A = A1 × A2 = (position, noise)` is a +hierarchical measure (a2 depends on a1), we can construct it using mbind by +setting merge as f_c: + +```julia +using MeasureBase, AffineMaps + +prior = mbind( + productmeasure(( + position = StdNormal(), + )), merge +) do a + productmeasure(( + noise = pushfwd(setinverse(sqrt, setladj(x -> x^2, x -> log(2))) ∘ Mul(abs(a.position)), StdExponential()), + )) end + +model = θ -> pushfwd(MulAdd(θ.noise, θ.position), StdNormal())^10 + +joint_θ_obs = mbind(model, prior, tuple) +prior_predictive = mbind(model, prior) + +observation = rand(prior_predictive) +likelihood = likelihoodof(model, observation) + +posterior = mintegrate(likelihood, prior) + +θ = rand(prior) +logdensityof(posterior, θ) ``` +""" +function mbind end +export mbind + +@inline mbind(f_β) = Base.Fix1(mbind, f_β) + +@inline mbind(f_β, α::AbstractMeasure, f_c = secondarg) = _generic_mbind_impl(f_β, asmeasure(α), f_c) + +@inline function _generic_mbind_impl(f_β, α::AbstractMeasure, f_c) + F, M, G = Core.Typeof(f_β), Core.Typeof(α), Core.Typeof(f_c) + Bind{F,M,G}(f_β, α, f_c) +end + +@inline _generic_mbind_impl(f_β, α::Dirac, f_c) = mcombine(f_c, α, f_β(α.x)) + +@inline _generic_mbind_impl(@nospecialize(f_β), α::AbstractMeasure, ::typeof(firstarg)) = α +@inline _generic_mbind_impl(@nospecialize(f_β), α::Dirac, ::typeof(firstarg)) = α + +@inline _generic_mbind_impl(f_k::MKernel, α::AbstractMeasure, ::typeof(secondarg)) = mbind(f_k.f_β, α, f_k.f_c) +@inline _generic_mbind_impl(f_k::MKernel, α::Dirac, ::typeof(secondarg)) = mbind(f_k.f_β, α, f_k.f_c) + -See also `bind` and `Bind` """ -↣(μ, k) = bind(μ, k) + struct MeasureBase.Bind <: AbstractMeasure + +Represents a monatic bind resp. a mbind in general. + +User code should not create instances of `Bind` directly, but should call +[`mbind`](@ref) instead. 
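A minimal usage sketch (with `pushfwd` and `MulAdd` from AffineMaps, as in the
[`mbind`](@ref) examples):

```julia
μ = mbind(a -> pushfwd(MulAdd(1.0, a), StdNormal()), StdNormal())
μ isa MeasureBase.Bind  # true, no simplification applies here
x = rand(μ)             # draws a from StdNormal(), then x from the dependent measure
```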
+""" +struct Bind{FT,M<:AbstractMeasure,FC} <: AbstractMeasure + f_β::FT + α::M + f_c::FC +end + +# ToDo: Store MKernel in Bind instead of separate fields f_β and f_c? + + +""" + bindkernel(μ::Bind)::MKernel + +Returns the monatic transition kernel of a monatic bind, so that +`bindkernel(mbind(f_k::MKernel, α)) == f_k`. + +See [`mbind`](@ref) and [`mkernel`](@ref) for details. +""" +function bindkernel end +export bindkernel + +bindkernel(μ::Bind) = mkernel(μ.f_β, μ.f_c) + + +""" + boundmeasure(μ::Bind)::MKernel + +Returns the measure that went into a monatic bind, so that +`boundmeasure(mbind(f_k, α)) == α`. + +See [`mbind`](@ref) and [`mkernel`](@ref) for details. +""" +function boundmeasure end +export boundmeasure + +boundmeasure(μ::Bind) = mkernel(μ.f_β, μ.f_c) + + +""" + MeasureBase.transportmeasure(μ::Bind, x)::AbstractMeasure + +Evaluates a monatic bind `μ` at a point `x`. + +The resulting measure behaves like `μ` in the infinitesimal neighborhood +of `x` in respect to density calculation and transport as well. +""" +function transportmeasure(μ::Bind, x) + tpm_α, a, b = tpmeasure_split_combined(μ.f_c, μ.α, x) + tpm_β_a = transportmeasure(_get_β_a(μ, a), b) + mcombine(μ.f_c, tpm_α, tpm_β_a) +end + +localmeasure(μ::Bind, x) = transportmeasure(μ, x) + +_get_β_a(μ::Bind, a) = asmeasure(μ.f_β(a)) + +tpmeasure_split_combined(f_c, μ::Bind, xy) = _bind_tpm_sc(f_c, μ::Bind, xy) + +function _bind_tpm_sc(::typeof(tuple), μ::Bind, xy::Tuple{Vararg{Any,2}}) + x, y = x[1], y[1] + tpm_μ = transportmeasure(μ, x) + return tpm_μ, x, y +end + +function _bind_tpm_sc(::Type{Pair}, μ::Bind, xy::Pair) + x, y = x.first, y.second + tpm_μ = transportmeasure(μ, x) + return tpm_μ, x, y +end + +const _BindBy{FC} = Bind{<:Any,<:AbstractMeasure,FC} +_bind_tpm_sc(f_c::typeof(vcat), μ::_BindBy{typeof(vcat)}, xy::AbstractVector) = _bind_tpm_sc_cat(f_c, μ, xy) +_bind_tpm_sc(f_c::typeof(merge), μ::_BindBy{typeof(merge)}, xy::NamedTuple) = _bind_tpm_sc_cat(f_c, μ, xy) + +function _bind_tpm_sc_cat_lμabyxy(f_c, μ, xy) + tpm_α, a, by = tpmeasure_split_combined(μ.f_c, μ.α, xy) + β_a = _get_β_a(μ, a) + tpm_β_a, b, y = tpmeasure_split_combined(f_c, β_a, by) + tpm_μ = mcombine(μ.f_c, tpm_α, tpm_β_a) + return tpm_μ, a, b, y, xy +end + +function _bind_tpm_sc_cat(f_c::typeof(vcat), μ::_BindBy{typeof(vcat)}, xy::AbstractVector) + tpm_μ, a, b, y, xy = _bind_tpm_sc_cat_lμabyxy(f_c, μ, xy) + # Don't use `x = f_c(a, b)` here, would allocate, splitting xy can use views: + x, y = _split_after(xy, length(a) + length(b)) + return tpm_μ, x, y +end + +function _bind_tpm_sc_cat(f_c::typeof(merge), μ::_BindBy{typeof(merge)}, xy::NamedTuple) + tpm_μ, a, b, y, xy = _bind_tpm_sc_cat_lμabyxy(f_c, μ, xy) + return tpm_μ, f_c(a, b), y +end + + +@inline insupport(μ::Bind, ::Any) = NoFastInsupport{typeof(μ)}() + +@inline getdof(μ::Bind) = NoDOF{typeof(μ)}() + +# Bypass `checked_arg`, would require potentially costly evaluation of h.f: +@inline checked_arg(::Bind, x) = x + +rootmeasure(::Bind) = throw(ArgumentError("root measure is implicit, but can't be instantiated, for Bind")) + +basemeasure(::Bind) = throw(ArgumentError("basemeasure is not available for Bind")) + +testvalue(::Bind) = throw(ArgumentError("testvalue is not available for Bind")) + +logdensity_def(::Bind, x) = throw(ArgumentError("logdensity_def is not available for Bind")) + +# Specialize logdensityof to avoid duplicate calculations: +function logdensityof(μ::Bind, x) + tpm_α, a, b = tpmeasure_split_combined(μ.f_c, μ.α, x) + β_a = _get_β_a(μ, a) + logdensityof(tpm_α, a) + 
logdensityof(β_a, b) +end + +# Specialize unsafe_logdensityof to avoid duplicate calculations: +function unsafe_logdensityof(μ::Bind, x) + tpm_α, a, b = tpmeasure_split_combined(μ.f_c, μ.α, x) + β_a = _get_β_a(μ, a) + unsafe_logdensityof(tpm_α, a) + unsafe_logdensityof(β_a, b) +end + + +function Base.rand(rng::Random.AbstractRNG, ::Type{T}, μ::Bind) where {T<:Real} + a = rand(rng, T, μ.α) + b = rand(rng, T, _get_β_a(μ, a)) + return μ.f_c(a, b) +end + +function Base.rand(rng::Random.AbstractRNG, μ::Bind) + a = rand(rng, μ.α) + b = rand(rng, _get_β_a(μ, a)) + return μ.f_c(a, b) +end + + +function transport_to_mvstd(ν_inner::StdMeasure, μ::Bind, ab) + tpm_α, a, b = tpmeasure_split_combined(μ.f_c, μ.α, ab) + β_a = _get_β_a(μ, a) + y1 = transport_to_mvstd(ν_inner, tpm_α, a) + y2 = transport_to_mvstd(ν_inner, β_a, b) + return vcat(y1, y2) +end -bind(μ, k) = Bind(μ, k) -function Base.rand(rng::AbstractRNG, ::Type{T}, d::Bind) where {T} - x = rand(rng, T, d.μ) - y = rand(rng, T, d.k(x)) - return y +function transport_from_mvstd_with_rest(ν::Bind, μ_inner::StdMeasure, x) + a, x2 = transport_from_mvstd_with_rest(ν.α, μ_inner, x) + β_a = ν.f_β(a) + b, x_rest = transport_from_mvstd_with_rest(β_a, μ_inner, x2) + return ν.f_c(a, b), x_rest end diff --git a/src/combinators/combined.jl b/src/combinators/combined.jl new file mode 100644 index 00000000..4127e952 --- /dev/null +++ b/src/combinators/combined.jl @@ -0,0 +1,155 @@ +""" + MeasureBase.tpmeasure_split_combined(f_c, α::AbstractMeasure, ab) + +Splits a combined value `ab` that originated from combining a point `a` +from the space of a measure `α` with a point `b` from the space of +another measure `β` via `ab = f_c(a, b)`. + +Returns a semantic equivalent of +`(MeasureBase.transportmeasure(α, a), a, b)`. + +With `a_orig = rand(α)`, `b_orig = rand(β)` and +`ab = f_c(a_orig, b_orig)`, the following must hold true: + +```julia +tpm_α, a, b = tpmeasure_split_combined(f_c, α, ab) +a ≈ a_orig && b ≈ b_orig +``` +""" +function tpmeasure_split_combined end + +function tpmeasure_split_combined(f_c, α::AbstractMeasure, ab) + a, b = _generic_split_combined(f_c, α, ab) + return transportmeasure(α, a), a, b +end + +@inline _generic_split_combined(::typeof(tuple), ::AbstractMeasure, x::Tuple{Vararg{Any,2}}) = x +@inline _generic_split_combined(::Type{Pair}, ::AbstractMeasure, ab::Pair) = (ab...,) + +function _generic_split_combined(f_c::FC, α::AbstractMeasure, ab) where FC + _split_variate_byvalue(f_c, testvalue(α), ab) +end + +_split_variate_byvalue(::typeof(vcat), test_a::AbstractVector, ab::AbstractVector) = _split_after(ab, length(test_a)) + +_split_variate_byvalue(::typeof(vcat), ::Tuple{N}, ab::Tuple) where N = _split_after(ab, Val{N}()) + +function _split_variate_byvalue(::typeof(merge), ::NamedTuple{names_a}, ab::NamedTuple) where names_a + _split_after(ab, Val(names_a)) +end + + +@doc raw""" + mcombine(f_c, α::AbstractMeasure, β::AbstractMeasure) + +Combines two measures `α` and `β` to a combined measure via a point combination +function `f_c`. + +`f_c` must combine a given point `a` from the space of measure `α` with a +given point `b` from the space of measure `β` to a single value +`ab = f_c(a, b)` in the space of the combined measure +`μ = mcombine(f_c, α, β)`. 
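For example (a sketch), with `f_c = tuple`:

```julia
μ = mcombine(tuple, StdNormal(), StdExponential())
ab = rand(μ)  # a pair (a, b), with a drawn from StdNormal() and b from StdExponential()
```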
+ +The combined measure has the mathethematical interpretation (on +sets $$A$$ and $$B$$) + +```math +\mu(f_c(A, B)) = \alpha(A)\, \beta(B) +``` +""" +function mcombine end +export mcombine + +@inline mcombine(f_c, α::AbstractMeasure, β::AbstractMeasure) = _generic_mcombine_impl_stage1(f_c, α, β) + +@inline _generic_mcombine_impl_stage1(::typeof(firstarg), α::AbstractMeasure, β::AbstractMeasure) = α +@inline _generic_mcombine_impl_stage1(::typeof(secondarg), α::AbstractMeasure, β::AbstractMeasure) = β + +@inline function _generic_mcombine_impl_stage1(::typeof(tuple), α::AbstractMeasure, β::AbstractMeasure) + productmeasure((α, β)) +end + +@inline function _generic_mcombine_impl_stage1(f_c::Union{typeof(vcat),typeof(merge)}, α::AbstractProductMeasure, β::AbstractProductMeasure) + productmeasure(f_c(marginals(α), marginals(β))) +end + +@inline function _generic_mcombine_impl_stage1(f_c, α::AbstractMeasure, β::AbstractMeasure) + _generic_mcombine_impl_stage2(f_c, α, β) +end + +@inline function _generic_mcombine_impl_stage2(f_c, α::AbstractMeasure, β::AbstractMeasure) + FC, MA, MB = Core.Typeof(f_c), Core.Typeof(α), Core.Typeof(β) + CombinedMeasure{FC,MA,MB}(f_c, α, β) +end + +@inline function _generic_mcombine_impl_stage2(f_c, α::Dirac, β::Dirac) + Dirac(f_c(α.x, β.x)) +end + +""" + struct CombinedMeasure <: AbstractMeasure + +Represents a combination of two measures. + +User code should not create instances of `CombinedMeasure` directly, but should call +[`mcombine(f_c, α, β)`](@ref) instead. +""" + +struct CombinedMeasure{FC,MA<:AbstractMeasure,MB<:AbstractMeasure} <: AbstractMeasure + f_c::FC + α::MA + β::MB +end + + +@inline insupport(μ::CombinedMeasure, ab) = NoFastInsupport{typeof(μ)}() + +@inline getdof(μ::CombinedMeasure) = getdof(μ.α) + getdof(μ.β) +@inline fast_dof(μ::CombinedMeasure) = fast_dof(μ.α) + fast_dof(μ.β) + +# Bypass `checked_arg`, would require require splitting ab: +@inline checked_arg(::CombinedMeasure, ab) = ab + +rootmeasure(μ::CombinedMeasure) = mcombine(μ.f_c, rootmeasure(μ.α), rootmeasure(μ.β)) + +basemeasure(μ::CombinedMeasure) = mcombine(μ.f_c, basemeasure(μ.α), basemeasure(μ.β)) + +function logdensity_def(μ::CombinedMeasure, ab) + # Use tpmeasure_split_combined to avoid duplicate calculation of transportmeasure(α): + tpm_α, a, b = tpmeasure_split_combined(μ.f_c, μ.α, ab) + return logdensity_def(tpm_α, a) + logdensity_def(μ.β, b) +end + +# Specialize logdensityof directly to avoid creating temporary combined base measures: +function logdensityof(μ::CombinedMeasure, ab) + tpm_α, a, b = tpmeasure_split_combined(μ.f_c, μ.α, ab) + return logdensityof(tpm_α, a) + logdensityof(μ.β, b) +end + + +function Base.rand(rng::Random.AbstractRNG, ::Type{T}, μ::CombinedMeasure) where {T<:Real} + a = rand(rng, T, μ.α) + b = rand(rng, T, μ.β) + return μ.f_c(a, b) +end + +function Base.rand(rng::Random.AbstractRNG, μ::CombinedMeasure) + a = rand(rng, μ.α) + b = rand(rng, μ.β) + return μ.f_c(a, b) +end + + +function transport_to_mvstd(ν_inner::StdMeasure, μ::CombinedMeasure, ab) + tpm_α, a, b = tpmeasure_split_combined(μ.f_c, μ.α, ab) + y1 = transport_to_mvstd(ν_inner, tpm_α, a) + y2 = transport_to_mvstd(ν_inner, μ.β, b) + return vcat(y1, y2) +end + + +function transport_from_mvstd_with_rest(ν::CombinedMeasure, μ_inner::StdMeasure, x) + a, x2 = transport_from_mvstd_with_rest(ν.α, μ_inner, x) + b, x_rest = transport_from_mvstd_with_rest(ν.β, μ_inner, x2) + return ν.f_c(a, b), x_rest +end diff --git a/src/combinators/likelihood.jl b/src/combinators/likelihood.jl index 
6dfd164f..863ea7d4 100644 --- a/src/combinators/likelihood.jl +++ b/src/combinators/likelihood.jl @@ -8,12 +8,12 @@ abstract type AbstractLikelihood end # ifelse(insupport(ℓ, p), t, f)() # end -# insupport(ℓ::AbstractLikelihood, p) = insupport(ℓ.k(p), ℓ.x) +# insupport(ℓ::AbstractLikelihood, p) = insupport(_eval_k(ℓ, p), ℓ.x) @doc raw""" - Likelihood(k::AbstractTransitionKernel, x) + Likelihood(k, x) -"Observe" a value `x`, yielding a function from the parameters to ℝ. +Default result of [`likelihoodof(k, x)`](@ref). Likelihoods are most commonly used in conjunction with an existing _prior_ measure to yield a new measure, the _posterior_. In Bayes's Law, we have @@ -64,39 +64,12 @@ With several parameters, things work as expected: --------- - Likelihood(M<:ParameterizedMeasure, constraint::NamedTuple, x) - -In some cases the measure might have several parameters, and we may want the -(log-)likelihood with respect to some subset of them. In this case, we can use -the three-argument form, where the second argument is a constraint. For example, - - julia> ℓ = Likelihood(Normal{(:μ,:σ)}, (σ=3.0,), 2.0) - Likelihood(Normal{(:μ, :σ), T} where T, (σ = 3.0,), 2.0) - -Similarly to the above, we have - - julia> density_def(ℓ, (μ=2.0,)) - 0.3333333333333333 - - julia> logdensity_def(ℓ, (μ=2.0,)) - -1.0986122886681098 - - julia> density_def(ℓ, 2.0) - 0.3333333333333333 - - julia> logdensity_def(ℓ, 2.0) - -1.0986122886681098 - ------------------------ - Finally, let's return to the expression for Bayes's Law, -``P(θ|x) ∝ P(θ) P(x|θ)`` +``P(θ|x) ∝ P(x|θ) P(θ)`` -The product on the right side is computed pointwise. To work with this in -MeasureBase, we have a "pointwise product" `⊙`, which takes a measure and a -likelihood, and returns a new measure, that is, the unnormalized posterior that -has density ``P(θ) P(x|θ)`` with respect to the base measure of the prior. +In measure theory, the product on the right side is the Lebesgue integral +of the likelihood with respect to the prior. For example, say we have @@ -104,24 +77,30 @@ For example, say we have x ~ Normal(μ,σ) σ = 1 -and we observe `x=3`. We can compute the posterior measure on `μ` as +and we observe `x=3`. We can compute the (non-normalized) posterior measure on +`μ` as - julia> post = Normal() ⊙ Likelihood(Normal{(:μ, :σ)}, (σ=1,), 3) - Normal() ⊙ Likelihood(Normal{(:μ, :σ), T} where T, (σ = 1,), 3) - - julia> logdensity_def(post, 2) - -2.5 + julia> prior = Normal() + julia> likelihood = Likelihood(μ -> Normal(μ, 1), 3) + julia> post = mintegrate(likelihood, prior) + julia> post isa MeasureBase.DensityMeasure + true + julia> logdensity_rel(post, Lebesgue(), 2) + -4.337877066409345 """ struct Likelihood{K,X} <: AbstractLikelihood k::K x::X - Likelihood(k::K, x::X) where {K<:AbstractTransitionKernel,X} = new{K,X}(k, x) - Likelihood(k::K, x::X) where {K<:Function,X} = new{K,X}(k, x) - Likelihood(μ, x) = Likelihood(kernel(μ), x) + Likelihood{K,X}(k, x) where {K,X} = new{K,X}(k, x) end -(lik::AbstractLikelihood)(p) = exp(ULogarithmic, logdensityof(lik.k(p), lik.x)) +# For type stability, in case k is a type (resp. 
a constructor): +Likelihood(k, x::X) where {X} = Likelihood{Core.Typeof(k),X}(k, x) + +(lik::AbstractLikelihood)(p) = exp(ULogarithmic, logdensityof(_eval_k(lik, p), lik.x)) + +_eval_k(ℓ::AbstractLikelihood, p) = asmeasure(ℓ.k(p)) DensityInterface.DensityKind(::AbstractLikelihood) = IsDensity() @@ -136,72 +115,100 @@ function Base.show(io::IO, ℓ::Likelihood) Pretty.pprint(io, ℓ) end -insupport(ℓ::AbstractLikelihood, p) = insupport(ℓ.k(p), ℓ.x) +insupport(ℓ::AbstractLikelihood, p) = insupport(_eval_k(ℓ, p), ℓ.x) @inline function logdensityof(ℓ::AbstractLikelihood, p) - logdensityof(ℓ.k(p), ℓ.x) + logdensityof(_eval_k(ℓ, p), ℓ.x) end @inline function unsafe_logdensityof(ℓ::AbstractLikelihood, p) - return unsafe_logdensityof(ℓ.k(p), ℓ.x) + return unsafe_logdensityof(_eval_k(ℓ, p), ℓ.x) end # basemeasure(ℓ::Likelihood) = @error "Likelihood requires local base measure" export likelihoodof -""" - likelihoodof(k::AbstractTransitionKernel, x; constraints...) - likelihoodof(k::AbstractTransitionKernel, x, constraints::NamedTuple) +@doc raw""" + likelihoodof(k, x) -A likelihood is *not* a measure. Rather, a likelihood acts on a measure, through -the "pointwise product" `⊙`, yielding another measure. -""" -function likelihoodof end +Returns the likelihood of observing `x` under a family of probability +measures that is generated by a transition kernel `k(θ)`. -likelihoodof(k, x, ::NamedTuple{()}) = Likelihood(k, x) +`k(θ)` maps points in the parameter space to measures (resp. objects that can +be converted to measures) on a implicit set `Χ` that contains values like `x`. -likelihoodof(k, x; kwargs...) = likelihoodof(k, x, NamedTuple(kwargs)) +`likelihoodof(k, x)` returns a likelihood object. A likelihhood is **not** a +measure, it is a function from the parameter space to `ℝ₊`. Likelihood +objects can also be interpreted as "generic densities" (but **not** as +probability densities). -likelihoodof(k, x, pars::NamedTuple) = likelihoodof(kernel(k, pars), x) +`likelihoodof(k, x)` implicitly chooses `ξ = rootmeasure(k(θ))` as the +reference measure on the observation set `Χ`. Note that this implicit +`ξ` **must** be independent of `θ`. -likelihoodof(k::AbstractTransitionKernel, x) = Likelihood(k, x) +`ℒₓ = likelihoodof(k, x)` has the mathematical interpretation -export log_likelihood_ratio +```math +\mathcal{L}_x(\theta) = \frac{\rm{d}\, k(\theta)}{\rm{d}\, \chi}(x) +``` +`likelihoodof` must return an object that implements the +[`DensityInterface`](https://github.com/JuliaMath/DensityInterface.jl)` API +and `ℒₓ = likelihoodof(k, x)` must satisfy + +```julia +log(ℒₓ(θ)) == logdensityof(ℒₓ, θ) ≈ logdensityof(k(θ), x) + +DensityKind(ℒₓ) isa IsDensity +``` + +By default, an instance of [`MeasureBase.Likelihood`](@ref) is returned. """ - log_likelihood_ratio(ℓ::Likelihood, p, q) +function likelihoodof end -Compute the log of the likelihood ratio, in order to compare two choices for -parameters. This is computed as +likelihoodof(k, x) = Likelihood(k, x) - logdensity_rel(ℓ.k(p), ℓ.k(q), ℓ.x) +############################################################################### +# At the least, we need to think through in some more detail whether +# (log-)likelihood ratios expressed in this way are correct and useful. For now +# this code is commented out; we may remove it entirely in the future. 
-Since `logdensity_rel` can leave common base measure unevaluated, this can be -more efficient than +# export log_likelihood_ratio - logdensityof(ℓ.k(p), ℓ.x) - logdensityof(ℓ.k(q), ℓ.x) -""" -log_likelihood_ratio(ℓ::Likelihood, p, q) = logdensity_rel(ℓ.k(p), ℓ.k(q), ℓ.x) +# """ +# log_likelihood_ratio(ℓ::Likelihood, p, q) -# likelihoodof(k, x; kwargs...) = likelihoodof(k, x, NamedTuple(kwargs)) +# Compute the log of the likelihood ratio, in order to compare two choices for +# parameters. This is computed as -export likelihood_ratio +# logdensity_rel(_eval_k(ℓ, p), ℓ.k(q), ℓ.x) -""" - likelihood_ratio(ℓ::Likelihood, p, q) +# Since `logdensity_rel` can leave common base measure unevaluated, this can be +# more efficient than -Compute the log of the likelihood ratio, in order to compare two choices for -parameters. This is equal to +# logdensityof(_eval_k(ℓ, p), ℓ.x) - logdensityof(ℓ.k(q), ℓ.x) +# """ +# log_likelihood_ratio(ℓ::Likelihood, p, q) = logdensity_rel(_eval_k(ℓ, p), ℓ.k(q), ℓ.x) - density_rel(ℓ.k(p), ℓ.k(q), ℓ.x) +# # likelihoodof(k, x; kwargs...) = likelihoodof(k, x, NamedTuple(kwargs)) -but is computed using LogarithmicNumbers.jl to avoid underflow and overflow. -Since `density_rel` can leave common base measure unevaluated, this can be -more efficient than +# export likelihood_ratio - logdensityof(ℓ.k(p), ℓ.x) - logdensityof(ℓ.k(q), ℓ.x) -""" -function likelihood_ratio(ℓ::Likelihood, p, q) - exp(ULogarithmic, logdensity_rel(ℓ.k(p), ℓ.k(q), ℓ.x)) -end +# """ +# likelihood_ratio(ℓ::Likelihood, p, q) + +# Compute the log of the likelihood ratio, in order to compare two choices for +# parameters. This is equal to + +# density_rel(_eval_k(ℓ, p), ℓ.k(q), ℓ.x) + +# but is computed using LogarithmicNumbers.jl to avoid underflow and overflow. +# Since `density_rel` can leave common base measure unevaluated, this can be +# more efficient than + +# logdensityof(_eval_k(ℓ, p), ℓ.x) - logdensityof(ℓ.k(q), ℓ.x) +# """ +# function likelihood_ratio(ℓ::Likelihood, p, q) +# exp(ULogarithmic, logdensity_rel(_eval_k(ℓ, p), ℓ.k(q), ℓ.x)) +# end diff --git a/src/combinators/pointwise.jl b/src/combinators/pointwise.jl deleted file mode 100644 index 778e7f4e..00000000 --- a/src/combinators/pointwise.jl +++ /dev/null @@ -1,30 +0,0 @@ -export ⊙ - -struct PointwiseProductMeasure{P,L} <: AbstractMeasure - prior::P - likelihood::L -end - -iterate(p::PointwiseProductMeasure, i = 1) = iterate((p.prior, p.likelihood), i) - -function Pretty.tile(d::PointwiseProductMeasure) - Pretty.pair_layout(Pretty.tile(d.prior), Pretty.tile(d.likelihood), sep = " ⊙ ") -end - -⊙(prior, ℓ) = pointwiseproduct(prior, ℓ) - -@inbounds function insupport(d::PointwiseProductMeasure, p) - prior, ℓ = d - istrue(insupport(prior, p)) && istrue(insupport(ℓ, p)) -end - -@inline function logdensity_def(d::PointwiseProductMeasure, p) - prior, ℓ = d - unsafe_logdensityof(ℓ, p) -end - -basemeasure(d::PointwiseProductMeasure) = d.prior - -function gentype(d::PointwiseProductMeasure) - gentype(d.prior) -end diff --git a/src/combinators/power.jl b/src/combinators/power.jl index a7fa24f8..d613cf60 100644 --- a/src/combinators/power.jl +++ b/src/combinators/power.jl @@ -1,5 +1,6 @@ import Base + export PowerMeasure """ @@ -11,6 +12,8 @@ the product determines the dimensionality of the resulting support. Note that power measures are only well-defined for integer powers. The nth power of a measure μ can be written μ^n. + +See also [`pwr_base`](@ref), [`pwr_axes`](@ref) and [`pwr_size`](@ref). 
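For example (a sketch):

```julia
μ = StdNormal()^3
x = rand(μ)  # three independent standard-normal draws
logdensityof(μ, x) ≈ sum(logdensityof(StdNormal(), xi) for xi in x)  # true
```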
""" struct PowerMeasure{M,A} <: AbstractProductMeasure parent::M @@ -20,6 +23,31 @@ end maybestatic_length(μ::PowerMeasure) = prod(maybestatic_size(μ)) maybestatic_size(μ::PowerMeasure) = map(maybestatic_length, μ.axes) + +""" + MeasureBase.pwr_base(μ::PowerMeasure) + +Returns `ν` for `μ = ν^axs` +""" +pwr_base(μ::PowerMeasure) = μ.parent + + +""" + MeasureBase.pwr_axes(μ::PowerMeasure) + +Returns `axs` for `μ = ν^axs`, `axs` being a tuple of integer ranges. +""" +pwr_axes(μ::PowerMeasure) = μ.axes + + +""" + MeasureBase.pwr_size(μ::PowerMeasure) + +Returns `sz` for `μ = ν^sz`, `sz` being a tuple of integers. +""" +pwr_size(μ::PowerMeasure) = map(maybestatic_length, μ.axes) + + function Pretty.tile(μ::PowerMeasure) sz = length.(μ.axes) arg1 = Pretty.tile(μ.parent) @@ -27,36 +55,46 @@ function Pretty.tile(μ::PowerMeasure) return Pretty.pair_layout(arg1, arg2; sep = " ^ ") end -# ToDo: Make rand return static arrays for statically-sized power measures. +# ToDo: Make rand and testvalue return static arrays for statically-sized power measures. function _cartidxs(axs::Tuple{Vararg{AbstractUnitRange,N}}) where {N} CartesianIndices(map(_dynamic, axs)) end -function Base.rand( - rng::AbstractRNG, - ::Type{T}, - d::PowerMeasure{M}, -) where {T,M<:AbstractMeasure} - map(_cartidxs(d.axes)) do _ - rand(rng, T, d.parent) - end +function Base.rand(rng::AbstractRNG, ::Type{T}, d::PowerMeasure) where {T} + map(_ -> rand(rng, T, d.parent), _cartidxs(d.axes)) end -function Base.rand(rng::AbstractRNG, ::Type{T}, d::PowerMeasure) where {T} - map(_cartidxs(d.axes)) do _ - rand(rng, d.parent) - end +function Base.rand(rng::AbstractRNG, d::PowerMeasure) + map(_ -> rand(rng, d.parent), _cartidxs(d.axes)) end -@inline _pm_axes(sz::Tuple{Vararg{IntegerLike,N}}) where {N} = map(one_to, sz) -@inline _pm_axes(axs::Tuple{Vararg{AbstractUnitRange,N}}) where {N} = axs +function Base.rand(rng::AbstractRNG, ::Type{T}, d::PowerMeasure{M,<:Tuple{Vararg{StaticOneTo}}}) where {T,M} + sz = pwr_size(d) + base_d = pwr_base(d) + broadcast(_ -> rand(rng, T, base_d), MeasureBase.maybestatic_fill(nothing, sz)) +end -@inline function powermeasure(x::T, sz::Tuple{Vararg{Any,N}}) where {T,N} - PowerMeasure(x, _pm_axes(sz)) +function Base.rand(rng::AbstractRNG, d::PowerMeasure{M,<:Tuple{Vararg{StaticOneTo}}}) where M + sz = pwr_size(d) + base_d = pwr_base(d) + broadcast(_ -> rand(rng, base_d), MeasureBase.maybestatic_fill(nothing, sz)) end -marginals(d::PowerMeasure) = fill_with(d.parent, d.axes) +function testvalue(::Type{T}, d::PowerMeasure) where {T} + map(_ -> testvalue(T, d.parent), _cartidxs(d.axes)) +end + +function testvalue(d::PowerMeasure) + map(_ -> testvalue(d.parent), _cartidxs(d.axes)) +end + + +@inline _pm_axes(::Tuple{}) = () +@inline _pm_axes(sz::Tuple{Vararg{IntegerLike,N}}) where {N} = map(one_to, sz) +@inline _pm_axes(axs::Tuple{Vararg{AbstractUnitRange,N}}) where {N} = axs + +marginals(d::PowerMeasure) = maybestatic_fill(d.parent, d.axes) function Base.:^(μ::AbstractMeasure, dims::Tuple{Vararg{AbstractArray,N}}) where {N} powermeasure(μ, dims) @@ -78,26 +116,21 @@ params(d::PowerMeasure) = params(first(marginals(d))) basemeasure(d.parent)^d.axes end -@inline function logdensity_def(d::PowerMeasure{M}, x) where {M} - parent = d.parent +@inline logdensity_def(d::PowerMeasure, x) = _pwr_logdensity_def(pwr_base(d), x, prod(pwr_size(d))) + +@inline _pwr_logdensity_def(d_base, x, ::Integer, ::StaticInteger{0}) = static(false) + +@inline function _pwr_logdensity_def(d_base, x, ::IntegerLike) sum(x) do xj - 
logdensity_def(parent, xj) + logdensity_def(d_base, xj) end end -@inline function logdensity_def(d::PowerMeasure{M,Tuple{Static.SOneTo{N}}}, x) where {M,N} - parent = d.parent - sum(1:N) do j - @inbounds logdensity_def(parent, x[j]) - end -end +# ToDo: Specialized version of _pwr_logdensity_def for statically-sized power measures + +# ToDo: Re-enable this? +# _pwr_logdensity_def(::PowerMeasure{P}, x, ::IntegerLike) where {P<:PrimitiveMeasure} = static(0.0) -@inline function logdensity_def( - d::PowerMeasure{M,NTuple{N,Static.SOneTo{0}}}, - x, -) where {M,N} - static(0.0) -end @inline function insupport(μ::PowerMeasure, x) p = μ.parent @@ -107,23 +140,33 @@ end end end +_all(A) = all(A) +_all(::AbstractArray{NoFastInsupport{T}}) where T = NoFastInsupport{T}() + + @inline function insupport(μ::PowerMeasure, x::AbstractArray) p = μ.parent - all(x) do xj + insupp = broadcast(x) do xj # https://github.com/SciML/Static.jl/issues/36 dynamic(insupport(p, xj)) end + _all(insupp) end -@inline getdof(μ::PowerMeasure) = getdof(μ.parent) * prod(map(length, μ.axes)) +@inline getdof(μ::PowerMeasure) = getdof(μ.parent) * prod(pwr_size(μ)) +@inline fast_dof(μ::PowerMeasure) = fast_dof(μ.parent) * prod(pwr_size(μ)) @inline function getdof(::PowerMeasure{<:Any,NTuple{N,Static.SOneTo{0}}}) where {N} static(0) end +@inline function fast_dof(::PowerMeasure{<:Any,NTuple{N,Static.SOneTo{0}}}) where {N} + static(0) +end + @propagate_inbounds function checked_arg(μ::PowerMeasure, x::AbstractArray{<:Any}) @boundscheck begin - sz_μ = map(length, μ.axes) + sz_μ = pwr_size(μ) sz_x = size(x) if sz_μ != sz_x throw(ArgumentError("Size of variate doesn't match size of power measure")) @@ -138,12 +181,14 @@ end massof(m::PowerMeasure) = massof(m.parent)^prod(m.axes) -logdensity_def(::PowerMeasure{P}, x) where {P<:PrimitiveMeasure} = static(0.0) -# To avoid ambiguities -function logdensity_def( - ::PowerMeasure{P,Tuple{Vararg{Static.SOneTo{0},N}}}, - x, -) where {P<:PrimitiveMeasure,N} - static(0.0) -end +""" + MeasureBase.StdPowerMeasure{MU<:StdMeasure,N} + +Represents and N-dimensional power of the standard measure `MU()`. +""" +const StdPowerMeasure{MU<:StdMeasure,N} = PowerMeasure{MU,<:NTuple{N,UnitRangeFromOne}} + +# ToDo: Fast specialized rand for static and non-static StdPowerMeasure! + +# ToDo: Define mrand and dispatch Base.rand to mrand to burden Base.rand with less methods! 
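# Usage sketch for the accessors above (assuming the exported standard measures):
#
#     μ = StdNormal()^5
#     pwr_base(μ)  # StdNormal()
#     pwr_size(μ)  # (5,), possibly as static integers
#     pwr_axes(μ)  # a tuple holding one integer range, 1:5 (possibly static)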
diff --git a/src/combinators/powerweighted.jl b/src/combinators/powerweighted.jl deleted file mode 100644 index 47f50da4..00000000 --- a/src/combinators/powerweighted.jl +++ /dev/null @@ -1,37 +0,0 @@ -export ↑ - -struct PowerWeightedMeasure{M,A} <: AbstractMeasure - parent::M - exponent::A -end - -logdensity_def(d::PowerWeightedMeasure, x) = d.exponent * logdensity_def(d.parent, x) - -basemeasure(d::PowerWeightedMeasure, x) = basemeasure(d.parent, x)↑d.exponent - -basemeasure(d::PowerWeightedMeasure) = basemeasure(d.parent)↑d.exponent - -function powerweightedmeasure(d, α) - isone(α) && return d - PowerWeightedMeasure(d, α) -end - -(d::AbstractMeasure)↑α = powerweightedmeasure(d, α) - -insupport(d::PowerWeightedMeasure, x) = insupport(d.parent, x) - -function Base.show(io::IO, d::PowerWeightedMeasure) - print(io, d.parent, " ↑ ", d.exponent) -end - -function powerweightedmeasure(d::PowerWeightedMeasure, α) - powerweightedmeasure(d.parent, α * d.exponent) -end - -function powerweightedmeasure(d::WeightedMeasure, α) - weightedmeasure(α * d.logweight, powerweightedmeasure(d.base, α)) -end - -function Pretty.tile(d::PowerWeightedMeasure) - Pretty.pair_layout(Pretty.tile(d.parent), Pretty.tile(d.exponent), sep = " ↑ ") -end diff --git a/src/combinators/product.jl b/src/combinators/product.jl index cb7a0aaf..2e8caee7 100644 --- a/src/combinators/product.jl +++ b/src/combinators/product.jl @@ -1,34 +1,45 @@ -export ProductMeasure - using MappedArrays using MappedArrays: ReadonlyMultiMappedArray using Base: @propagate_inbounds import Base using FillArrays +using Random: rand!, GLOBAL_RNG, AbstractRNG -export AbstractProductMeasure -abstract type AbstractProductMeasure <: AbstractMeasure end +""" + struct MeasureBase.ProductMeasure{M} <: AbstractProductMeasure + +Represents a products of measures. -function Pretty.tile(μ::AbstractProductMeasure) - result = Pretty.literal("ProductMeasure(") - result *= Pretty.tile(marginals(μ)) - result *= Pretty.literal(")") +´ProductMeasure` wraps a collection of measures, this collection then +becomes the collection of the marginal measures of the `ProductMeasure`. + +User code should not instantiate `ProductMeasure` directly, but should call +[`productmeasure`](@ref) instead. +""" +struct ProductMeasure{M} <: AbstractProductMeasure + marginals::M +end + +function Pretty.tile(d::ProductMeasure{T}) where {T<:Tuple} + Pretty.list_layout(Pretty.tile.([marginals(d)...]), sep = " ⊗ ") end -massof(m::AbstractProductMeasure) = prod(massof, marginals(m)) +marginals(μ::ProductMeasure) = μ.marginals + +proxy(μ::ProductMeasure{<:Fill}) = powermeasure(_fill_value(marginals(μ)), _fill_axes(marginals(μ))) -export marginals -function Base.:(==)(a::AbstractProductMeasure, b::AbstractProductMeasure) - marginals(a) == marginals(b) +# TODO: Better `map` support in MappedArrays +_map(f, args...) = map(f, args...) 
+_map(f, x::MappedArrays.ReadonlyMappedArray) = mappedarray(fchain((x.f, f)), x.data) + +function testvalue(::Type{T}, μ::ProductMeasure) where {T} + _map(m -> testvalue(T, m), marginals(μ)) end -Base.length(μ::AbstractProductMeasure) = length(marginals(μ)) -Base.size(μ::AbstractProductMeasure) = size(marginals(μ)) -basemeasure(d::AbstractProductMeasure) = productmeasure(map(basemeasure, marginals(d))) -function Base.rand(rng::AbstractRNG, ::Type{T}, d::AbstractProductMeasure) where {T} +function Base.rand(rng::AbstractRNG, ::Type{T}, d::ProductMeasure) where {T} mar = marginals(d) _rand_product(rng, T, mar, eltype(mar)) end @@ -72,119 +83,70 @@ function _rand_product( end |> collect end -@inline function logdensity_def(d::AbstractProductMeasure, x) - mapreduce(logdensity_def, +, marginals(d), x) -end -struct ProductMeasure{M} <: AbstractProductMeasure - marginals::M +@inline function logdensity_def(μ::ProductMeasure, x) + _marginals_density_op(logdensity_def, marginals(μ), x) +end +@inline function unsafe_logdensityof(μ::ProductMeasure, x) + _marginals_density_op(unsafe_logdensityof, marginals(μ), x) end - @inline function logdensity_rel(μ::ProductMeasure, ν::ProductMeasure, x) - mapreduce(logdensity_rel, +, marginals(μ), marginals(ν), x) + _marginals_density_op(logdensity_rel, marginals(μ), marginals(ν), x) end -function Pretty.tile(d::ProductMeasure{T}) where {T<:Tuple} - Pretty.list_layout(Pretty.tile.([marginals(d)...]), sep = " ⊗ ") +function _marginals_density_op(density_op::F, marginals_μ, x) where F + mapreduce(density_op, +, marginals_μ, x) +end +@inline function _marginals_density_op(density_op::F, marginals_μ::Tuple, x::Tuple) where F + # For tuples, `mapreduce` can have trouble with type inference + sum(map(density_op, marginals_μ, x)) +end +@inline function _marginals_density_op(density_op::F, marginals_μ::NamedTuple{names}, x::NamedTuple) where {F,names} + _marginals_density_op(density_op, values(marginals_μ), values(NamedTuple{names}(x))) end -# For tuples, `mapreduce` has trouble with type inference -@inline function logdensity_def(d::ProductMeasure{T}, x) where {T<:Tuple} - ℓs = map(logdensity_def, marginals(d), x) - sum(ℓs) +function _marginals_density_op(density_op::F, marginals_μ, marginals_ν, x) where F + mapreduce(density_op, +, marginals_μ, marginals_ν, x) +end +@inline function _marginals_density_op(density_op::F, marginals_μ::Tuple, marginals_ν::Tuple, x::Tuple) where F + # For tuples, `mapreduce` can have trouble with type inference + sum(map(density_op, marginals_μ, marginals_ν, x)) +end +@inline function _marginals_density_op(density_op::F, marginals_μ::NamedTuple{names}, marginals_ν::NamedTuple, x::NamedTuple) where {F,names} + _marginals_density_op(density_op, values(marginals_μ), values(NamedTuple{names}(marginals_ν)), values(NamedTuple{names}(x))) end -@generated function logdensity_def(d::ProductMeasure{NamedTuple{N,T}}, x) where {N,T} - k1 = QuoteNode(first(N)) - q = quote - m = marginals(d) - ℓ = logdensity_def(getproperty(m, $k1), getproperty(x, $k1)) - end - for k in Base.tail(N) - k = QuoteNode(k) - qk = :(ℓ += logdensity_def(getproperty(m, $k), getproperty(x, $k))) - push!(q.args, qk) - end - return q -end +@inline basemeasure(μ::ProductMeasure) = _marginals_basemeasure(marginals(μ)) -# @generated function basemeasure(d::ProductMeasure{NamedTuple{N,T}}, x) where {N,T} -# q = quote -# m = marginals(d) -# end -# for k in N -# qk = QuoteNode(k) -# push!(q.args, :($k = basemeasure(getproperty(m, $qk)))) -# end +_marginals_basemeasure(marginals_μ) = 
productmeasure(map(basemeasure, marginals_μ)) -# vals = map(x -> Expr(:(=), x,x), N) -# push!(q.args, Expr(:tuple, vals...)) -# return q -# end -function basemeasure(μ::ProductMeasure{Base.Generator{I,F}}) where {I,F} - mar = marginals(μ) - T = Core.Compiler.return_type(mar.f, Tuple{eltype(mar.iter)}) - B = Core.Compiler.return_type(basemeasure, Tuple{T}) - _basemeasure(μ, B, static(Base.issingletontype(B))) -end +# I <: Base.Generator -function basemeasure(μ::ProductMeasure{A}) where {T,A<:AbstractMappedArray{T}} +function _marginals_basemeasure(marginals_μ::Base.Generator{I,F}) where {I,F} + T = Core.Compiler.return_type(marginals_μ.f, Tuple{eltype(marginals_μ.iter)}) B = Core.Compiler.return_type(basemeasure, Tuple{T}) - _basemeasure(μ, B, static(Base.issingletontype(B))) + _marginals_basemeasure_impl(marginals_μ, B, static(Base.issingletontype(B))) end -function _basemeasure(μ::ProductMeasure, ::Type{B}, ::True) where {B} - return instance(B)^axes(marginals(μ)) +function _marginals_basemeasure(marginals_μ::AbstractMappedArray{T}) where {T} + B = Core.Compiler.return_type(basemeasure, Tuple{T}) + _marginals_basemeasure_impl(marginals_μ, B, static(Base.issingletontype(B))) end -function _basemeasure( - μ::ProductMeasure{A}, - ::Type{B}, - ::False, -) where {T,A<:AbstractMappedArray{T},B} - mar = marginals(μ) - productmeasure(mappedarray(basemeasure, mar)) +function _marginals_basemeasure_impl(marginals_μ, ::Type{B}, ::True) where {B} + instance(B)^axes(marginals_μ) end -function _basemeasure( - μ::ProductMeasure{Base.Generator{I,F}}, - ::Type{B}, - ::False, -) where {I,F,B} - mar = marginals(μ) - productmeasure(Base.Generator(basekernel(mar.f), mar.iter)) +function _marginals_basemeasure_impl(marginals_μ::AbstractMappedArray{T}, ::Type{B}, ::False) where {T,B} + productmeasure(mappedarray(basemeasure, marginals_μ)) end -marginals(μ::ProductMeasure) = μ.marginals - -# TODO: Better `map` support in MappedArrays -_map(f, args...) = map(f, args...) -_map(f, x::MappedArrays.ReadonlyMappedArray) = mappedarray(fchain((x.f, f)), x.data) - -function testvalue(::Type{T}, d::AbstractProductMeasure) where {T} - _map(m -> testvalue(T, m), marginals(d)) +function _marginals_basemeasure_impl(marginals_μ::Base.Generator{I,F}, ::Type{B}, ::False) where {I,F,B} + productmeasure(Base.Generator(basekernel(marginals_μ.f), marginals_μ.iter)) end -export ⊗ - -""" - ⊗(μs::AbstractMeasure...) - -`⊗` is a binary operator for building product measures. This satisfies the law - -``` - basemeasure(μ ⊗ ν) == basemeasure(μ) ⊗ basemeasure(ν) -``` -""" -⊗(μs::AbstractMeasure...) = productmeasure(μs) - -############################################################################### -# I <: Base.Generator - -export rand! -using Random: rand!, GLOBAL_RNG, AbstractRNG @propagate_inbounds function Random.rand!( rng::AbstractRNG, @@ -199,8 +161,6 @@ using Random: rand!, GLOBAL_RNG, AbstractRNG return x end -export rand! 
-using Random: rand!, GLOBAL_RNG function _rand(rng::AbstractRNG, ::Type{T}, d::ProductMeasure, mar::AbstractArray) where {T} elT = typeof(rand(rng, T, first(mar))) @@ -210,7 +170,7 @@ function _rand(rng::AbstractRNG, ::Type{T}, d::ProductMeasure, mar::AbstractArra rand!(rng, d, x) end -@inline function insupport(d::AbstractProductMeasure, x::AbstractArray) +@inline function insupport(d::ProductMeasure, x::AbstractArray) mar = marginals(d) # We might get lucky and know statically that everything is inbounds T = Core.Compiler.return_type(insupport, Tuple{eltype(mar),eltype(x)}) @@ -219,14 +179,18 @@ end end end -@inline function insupport(d::AbstractProductMeasure, x) +@inline function insupport(d::ProductMeasure, x) for (mj, xj) in zip(marginals(d), x) - dynamic(insupport(mj, xj)) || return false + insup = dynamic(insupport(mj, xj)) + if insup isa NoFastInsupport || insup == false + return insup + end end return true end -getdof(d::AbstractProductMeasure) = mapreduce(getdof, +, marginals(d)) +getdof(d::ProductMeasure) = sum(getdof, marginals(d)) +fast_dof(d::ProductMeasure) = sum(fast_dof, marginals(d)) function checked_arg(μ::ProductMeasure{<:NTuple{N,Any}}, x::NTuple{N,Any}) where {N} map(checked_arg, marginals(μ), x) @@ -238,3 +202,4 @@ function checked_arg( ) where {names} NamedTuple{names}(map(checked_arg, values(marginals(μ)), values(x))) end + diff --git a/src/combinators/product_transport.jl b/src/combinators/product_transport.jl new file mode 100644 index 00000000..f87219cf --- /dev/null +++ b/src/combinators/product_transport.jl @@ -0,0 +1,313 @@ +""" + transport_to(ν, ::Type{MU}) where {NU<:StdMeasure} + transport_to(::Type{NU}, μ) where {NU<:StdMeasure} + +As a user convenience, a standard measure type like [`StdUniform`](@ref), +[`StdExponential`](@ref), [`StdNormal`](@ref) or [`StdLogistic`](@ref) +may be used directly as the source or target a measure transport. + +Depending on [`some_getdof(μ)`](@ref) (resp. `ν`), an instance of the +standard measure itself or a power of it will be automatically chosen as +the transport partner. + +Example: + +```julia +transport_to(StdNormal, μ) +transport_to(ν, StdNormal) +``` +""" +function transport_to(ν, ::Type{MU}) where {MU<:StdMeasure} + transport_to(ν, _std_tp_partner(MU, ν)) +end + +function transport_to(::Type{NU}, μ) where {NU<:StdMeasure} + transport_to(_std_tp_partner(NU, μ), μ) +end + +function transport_to(::Type{NU}, ::Type{MU}) where {NU<:StdMeasure,MU<:StdMeasure} + throw(ArgumentError("Can't construct a transport function between the type of two standard measures, need a measure instance on one side")) +end + +_std_tp_partner(::Type{M}, μ) where {M<:StdMeasure} = _std_tp_partner_bydof(M, some_dof(μ)) +_std_tp_partner_bydof(::Type{M}, ::StaticInteger{1}) where {M<:StdMeasure} = M() +_std_tp_partner_bydof(::Type{M}, dof::IntegerLike) where {M<:StdMeasure} = M()^dof +function _std_tp_partner_bydof(::Type{M}, dof::AbstractNoDOF{MU}) where {M<:StdMeasure,MU} + throw(ArgumentError("Can't determine a standard transport partner for measures of type $(nameof(typeof(MU)))")) +end + + +# For transport, always pull a PowerMeasure back to one-dimensional PowerMeasure first: + +transport_origin(μ::PowerMeasure{<:Any,N}) where N = transport_origin(μ.parent)^prod(pwr_size(μ)) + +function from_origin(μ::PowerMeasure{<:Any,N}, x_origin) where N + # Sanity check, should never fail: + @assert x_origin isa AbstractVector + return maybestatic_reshape(x_origin, pwr_size(μ)...) 
+end + + +# A one-dimensional PowerMeasure has an origin if its parent has an origin: + +transport_origin(μ::PowerMeasure{<:AbstractMeasure,1}) = _pwr_origin(typeof(μ), pwr_base(μ), pwr_axes(μ)) +_pwr_origin(::Type{MU}, parent_origin, axes) where MU = parent_origin^axes +_pwr_origin(::Type{MU}, ::NoTransportOrigin, axes) where MU = NoTransportOrigin{MU} + +function from_origin(μ::PowerMeasure{<:AbstractMeasure,1}, x_origin) + # Sanity check, should never fail: + @assert x_origin isa AbstractVector + from_origin.(Ref(μ.parent), x_origin) +end + +# Specialize for case of equal bases. Because of StdPowerMeasure methods below +# specify `,1`, we need extra methods to avoid ambiguity + +function transport_def(ν::StdPowerMeasure{MU}, μ::StdPowerMeasure{MU}, x) where MU + reshape(x, ν.axes) +end + +function transport_def(ν::StdPowerMeasure{MU,1}, μ::StdPowerMeasure{MU}, x) where MU + reshape(x, ν.axes) +end + +function transport_def(ν::StdPowerMeasure{MU}, μ::StdPowerMeasure{MU,1}, x) where MU + reshape(x, ν.axes) +end + +function transport_def(ν::StdPowerMeasure{MU,1}, μ::StdPowerMeasure{MU,1}, x) where MU + reshape(x, ν.axes) +end + +# Transport between univariate standard measures and 1-dim power measures of size one: + +function transport_def(ν::StdMeasure, μ::StdPowerMeasure{MU,1}, x) where {MU} + return transport_def(ν, μ.parent, only(x)) +end + +function transport_def(ν::StdPowerMeasure{NU,1}, μ::StdMeasure, x) where {NU} + sz_ν = pwr_size(ν) + @assert prod(sz_ν) == 1 + return maybestatic_fill(transport_def(ν.parent, μ, x), sz_ν) +end + +function transport_def(ν::StdPowerMeasure{NU,1}, μ::StdPowerMeasure{MU,1}, x) where {NU,MU} + reshape(transport_to(ν.parent, μ.parent).(x), ν.axes) +end + + +# Transport to a multivariate standard measure from any measure: + +function transport_def(ν::StdPowerMeasure{MU,1}, μ::AbstractMeasure, x) where MU + ν_inner = pwr_base(ν) + transport_to_mvstd(ν_inner, μ, x) +end + +function transport_to_mvstd(ν_inner::StdMeasure, μ::AbstractMeasure, x) + return _to_mvstd_withdof(ν_inner, μ, fast_dof(μ), x) +end + +function _to_mvstd_withdof(ν_inner::StdMeasure, μ::AbstractMeasure, dof_μ::IntegerLike, x) + y = transport_def(ν_inner^dof_μ, μ, x) + return y +end + +function _to_mvstd_withdof(ν_inner::StdMeasure, μ::AbstractMeasure, ::AbstractNoDOF, x) + _to_mvstd_withorigin(ν_inner, μ, transport_origin(μ), x) +end + +function _to_mvstd_withorigin(ν_inner::StdMeasure, μ::AbstractMeasure, μ_origin, x) + x_origin = transport_to_mvstd(ν_inner, μ_origin, x) + from_origin(μ, x_origin) +end + +function _to_mvstd_withorigin(ν_inner::StdMeasure, μ::AbstractMeasure, ::NoTransportOrigin, x) + throw(ArgumentError("Don't know how to transport values of type $(nameof(typeof(x))) from $(nameof(typeof(μ))) to a power of $(nameof(typeof(ν_inner)))")) +end + + +# Transport from a multivariate standard measure to any measure: + +function transport_def(ν::AbstractMeasure, μ::StdPowerMeasure{MU,1}, x) where {MU} + μ_inner = pwr_base(μ) + _transport_from_mvstd(ν, μ_inner, x) +end + +function _transport_from_mvstd(ν::AbstractMeasure, μ_inner::StdMeasure, x) + y, x_rest = transport_from_mvstd_with_rest(ν, μ_inner, x) + if !isempty(x_rest) + throw(ArgumentError("Input value too long during transport")) + end + return y +end + +function transport_from_mvstd_with_rest(ν::AbstractMeasure, μ_inner::StdMeasure, x) + dof_ν = fast_dof(ν) + return _from_mvstd_with_rest_withdof(ν, dof_ν, μ_inner, x) +end + +function _from_mvstd_with_rest_withdof(ν::AbstractMeasure, dof_ν::IntegerLike, 
μ_inner::StdMeasure, x) + len_x = length(eachindex(x)) + + # Since we can't check DOF of original Bind, we could "run out x" if + # the original x was too short. `transport_to` below will detect this, but better + # throw a more informative exception here: + if len_x < dof_ν + throw(ArgumentError("Variate too short during transport involving Bind")) + end + + x_inner_dof, x_rest = _split_after(x, dof_ν) + y = transport_to(ν, μ_inner^dof_ν, x_inner_dof) + return y, x_rest +end + +function _from_mvstd_with_rest_withdof(ν::AbstractMeasure, ::AbstractNoDOF, μ_inner::StdMeasure, x) + _from_mvstd_with_rest_withorigin(ν, transport_origin(ν), μ_inner, x) +end + +function _from_mvstd_with_rest_withorigin(ν::AbstractMeasure, ν_origin, μ_inner::StdMeasure, x) + x_origin, x_rest = transport_from_mvstd_with_rest(ν_origin, x, μ_inner) + from_origin(ν, x_origin), x_rest +end + +function _from_mvstd_with_rest_withorigin(ν::AbstractMeasure, ::NoTransportOrigin, μ_inner::StdMeasure, x) + throw(ArgumentError("Don't know how to transport value of type $(nameof(typeof(x))) from power of $(nameof(typeof(μ_inner))) to $(nameof(typeof(ν)))")) +end + + +# Transport between a standard measure and Dirac: + +@inline transport_from_mvstd_with_rest(ν::Dirac, ::StdMeasure, x::Any) = ν.x, x + +@inline transport_to_mvstd(::StdMeasure, ::Dirac, ::Any) = Zeros{Bool}(0) + + + + + +@inline transport_origin(μ::ProductMeasure) = _marginals_tp_origin(marginals(μ)) +@inline from_origin(μ::ProductMeasure, x_origin) = _marginals_from_origin(marginals(μ), x_origin) + +_marginals_tp_origin(::Ms) where Ms = NoTransportOrigin{ProductMeasure{Ms}}() + + +# Pull back from a product over a Fill to a power measure: + +_marginals_tp_origin(marginals_μ::Fill) = marginals_μ.value^marginals_μ.axes +_marginals_from_origin(::Fill, x_origin) = x_origin + + +# Pull back from a NamedTuple product measure to a Tuple product measure: +# +# Maybe ToDo (breaking): For transport between NamedTuple-marginals we could +# match names where possible, even if given in different order, and transport +# between the remaining non-matching names in the order given. This may not +# be worth the additional complexity, though, since transport is typically +# used with a (power of a) standard measure on one side. 
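# Sketch of the NamedTuple pullback (the measures here are only illustrative):
#
#     ν = productmeasure((a = StdUniform(), b = StdNormal()))
#     # transport goes through the unnamed origin productmeasure((StdUniform(), StdNormal()));
#     # from_origin then re-attaches the names (:a, :b) to the transported values.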
+ +_marginals_tp_origin(marginals_μ::NamedTuple{names}) where names = productmeasure(values(marginals_μ)) +_marginals_from_origin(::NamedTuple{names}, x_origin::NamedTuple) where names = NamedTuple{names}(x_origin) + + +# Transport between two instances of ProductMeasure: + +transport_def(ν::ProductMeasure, μ::ProductMeasure, x) = _marginal_transport_def(marginals(ν), marginals(μ), x) + +function _marginal_transport_def(marginals_ν, marginals_μ, x) + @assert size(marginals_ν) == size(marginals_μ) == size(x) # Sanity check, should not fail + transport_def.(marginals_ν, marginals_μ, x) +end + +function _marginal_transport_def(marginals_ν::Tuple{Vararg{AbstractMeasure,N}}, marginals_μ::Tuple{Vararg{AbstractMeasure,N}}, x) where N + @assert x isa Tuple{Vararg{AbstractMeasure,N}} # Sanity check, should not fail + map(transport_def, marginals_ν, marginals_μ, x) +end + +function _marginal_transport_def(marginals_ν::AbstractVector{<:AbstractMeasure}, marginals_μ::Tuple{Vararg{AbstractMeasure,N}}, x) where N + _marginal_transport_def(_as_tuple(marginals_ν, Val(N)), marginals_μ, x) +end + +function _marginal_transport_def(marginals_ν::Tuple{Vararg{AbstractMeasure,N}}, marginals_μ::AbstractVector{<:AbstractMeasure}, x) where N + _marginal_transport_def(marginals_ν, _as_tuple(marginals_μ, Val(N)), _as_tuple(x, Val(N))) +end + + + +# Transport from ProductMeasure to StdMeasure type: + +function transport_to_mvstd(ν_inner::StdMeasure, μ::ProductMeasure, x) + _marginals_to_mvstd(ν_inner, marginals(μ), x) +end + +struct _TransportToMvStd{NU<:StdMeasure} <: Function end +(::_TransportToMvStd{NU})(μ, x) where {NU} = transport_to_mvstd(NU(), μ, x) + +function _marginals_to_mvstd(::NU, marginals_μ::Tuple, x::Tuple) where {NU<:StdMeasure} + _flatten_to_rv(map(_TransportToMvStd{NU}(), marginals_μ, x)) +end + +function _marginals_to_mvstd(::NU, marginals_μ, x) where {NU<:StdMeasure} + _flatten_to_rv(broadcast(_TransportToMvStd{NU}(), marginals_μ, x)) +end + + + +# Transport StdMeasure type to ProductMeasure, with rest: + +const _MaybeUnkownDOF = Union{IntegerLike,AbstractNoDOF} + +const _KnownDOFs = Union{Tuple{Vararg{IntegerLike,N}} where N, StaticVector{<:IntegerLike}} +const _MaybeUnkownKnownDOFs = Union{Tuple{Vararg{_MaybeUnkownDOF,N}} where N, StaticVector{<:_MaybeUnkownDOF}} + +function transport_from_mvstd_with_rest(ν::ProductMeasure, μ_inner::StdMeasure, x) + νs = marginals(ν) + dofs = map(fast_dof, νs) + return _marginals_from_mvstd_with_rest(νs, dofs, μ_inner, x) +end + +function _dof_access_firstidxs(dofs::Tuple{Vararg{IntegerLike,N}}, first_idx) where N + cumsum((first_idx, dofs[begin:end-1]...)) +end + +function _dof_access_firstidxs(dofs::AbstractVector{<:IntegerLike}, first_idx) + # ToDo: Improve imlementation (reduce memory allocations) + cumsum(vcat([eltype(dofs)(first_idx)], dofs[begin:end-1])) +end + +function _split_x_by_marginals_with_rest(dofs::Union{Tuple,AbstractVector}, x::AbstractVector{<:Real}) + x_idxs = maybestatic_eachindex(x) + first_idxs = _dof_access_firstidxs(dofs, maybestatic_first(x_idxs)) + xs = map((from, n) -> _get_or_view(x, from, from + n - one(n)), first_idxs, dofs) + x_rest = _get_or_view(x, first_idxs[end] + dofs[end], maybestatic_last(x_idxs)) + return xs, x_rest +end + +function _marginals_from_mvstd_with_rest(νs, dofs::_KnownDOFs, μ_inner::StdMeasure, x::AbstractVector{<:Real}) + xs, x_rest = _split_x_by_marginals_with_rest(dofs, x) + # ToDo: Is this ideal? 
+ μs = map(n -> μ_inner^n, dofs) + ys = map(transport_def, νs, μs, xs) + return ys, x_rest +end + +function _marginals_from_mvstd_with_rest(νs, ::_MaybeUnkownKnownDOFs, μ_inner::StdMeasure, x::AbstractVector{<:Real}) + _marginals_from_mvstd_with_rest_nodof(νs, μ_inner, x) +end + +function _marginals_from_mvstd_with_rest_nodof(νs::Tuple{Vararg{AbstractMeasure,N}}, μ_inner::StdMeasure, x::AbstractVector{<:Real}) where N + # ToDo: Check for type stability, may need generated function + y1, x_rest = transport_from_mvstd_with_rest(νs[1], μ_inner, x) + y2_end, x_final_rest = _marginals_from_mvstd_with_rest_nodof(νs[2:end], μ_inner, x_rest) + return (y1, y2_end...), x_final_rest +end + +function _marginals_from_mvstd_with_rest_nodof(νs::AbstractVector{<:AbstractMeasure}, μ_inner::StdMeasure, x::AbstractVector{<:Real}) + # ToDo: Check for type stability, may need generated function + y1, x_rest = transport_from_mvstd_with_rest(νs[1], μ_inner, x) + ys = [y1] + for ν in νs[begin+1:end] + y_i, x_rest = _marginals_from_mvstd_with_rest_nodof(ν, μ_inner, x_rest) + ys = vcat(ys, y_i) + end + return ys, x_rest +end diff --git a/src/combinators/smart-constructors.jl b/src/combinators/smart-constructors.jl index 26ba3948..e803896b 100644 --- a/src/combinators/smart-constructors.jl +++ b/src/combinators/smart-constructors.jl @@ -1,62 +1,111 @@ +# Canonical measure type nesting, outer to inner: +# +# WeightedMeasure, Dirac, PowerMeasure, ProductMeasure + + ############################################################################### # Half half(μ::AbstractMeasure) = Half(μ) ############################################################################### -# PointwiseProductMeasure +# PowerMeaure -function pointwiseproduct(μ::AbstractMeasure, ℓ::Likelihood) - T = Core.Compiler.return_type(ℓ.k, Tuple{gentype(μ)}) - return pointwiseproduct(T, μ, ℓ) -end +""" + powermeasure(μ, dims) + powermeasure(μ, axes) -function pointwiseproduct(::Type{T}, μ::AbstractMeasure, ℓ::Likelihood) where {T} - return PointwiseProductMeasure(μ, ℓ) -end +Constructs a power of a measure `μ`. -############################################################################### -# PowerMeaure +`powermeasure(μ, exponent)` is semantically equivalent to +`productmeasure(Fill(μ, exponent))`, but more efficient. 
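+
+For example (illustrative; `StdNormal` is the standard normal measure defined
+in this package):
+
+```julia
+powermeasure(StdNormal(), (3,))     # behaves like StdNormal()^3
+powermeasure(StdNormal(), (2, 3))   # a 2×3 array-shaped power of StdNormal()
+```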
+""" +function powermeasure end +export powermeasure -powermeasure(m::AbstractMeasure, ::Tuple{}) = m +@inline powermeasure(μ, exponent) = _generic_powermeasure_stage1(asmeasure(μ), _pm_axes(exponent)) -function powermeasure( - μ::WeightedMeasure, - dims::Tuple{<:AbstractArray,Vararg{AbstractArray}}, -) - k = mapreduce(length, *, dims) * μ.logweight - return weightedmeasure(k, μ.base^dims) +@inline _generic_powermeasure_stage1(μ::AbstractMeasure, ::Tuple{}) = μ + +@inline function _generic_powermeasure_stage1(μ::AbstractMeasure, exponent::Tuple) + _generic_powermeasure_stage2(μ, exponent) +end + +@inline _generic_powermeasure_stage2(μ::AbstractMeasure, exponent::Tuple) = PowerMeasure(μ, exponent) + +@inline function _generic_powermeasure_stage2(μ::Dirac, exponent::Tuple) + Dirac(maybestatic_fill(μ.x, exponent)) end -function powermeasure(μ::WeightedMeasure, dims::NonEmptyTuple) - k = prod(dims) * μ.logweight - return weightedmeasure(k, μ.base^dims) +@inline function _generic_powermeasure_stage2(μ::WeightedMeasure, exponent::Tuple) + ν = μ.base^exponent + k = maybestatic_length(ν) * μ.logweight + return weightedmeasure(k, ν) end + ############################################################################### # ProductMeasure -productmeasure(mar::FillArrays.Fill) = powermeasure(mar.value, mar.axes) +""" + productmeasure(μs) + +Constructs a product over a collection `μs` of measures. + +Examples: + +```julia +using MeasureBase, AffineMaps +productmeasure((StdNormal(), StdExponential())) +productmeasure(a = StdNormal(), b = StdExponential())) +productmeasure([pushfwd(Mul(scale), StdExponential()) for scale in 0.1:0.2:2]) +productmeasure((pushfwd(Mul(scale), StdExponential()) for scale in 0.1:0.2:2)) +""" +function productmeasure end +export productmeasure + +@inline productmeasure(mar) = _generic_productmeasure_impl(mar) + +@inline _generic_productmeasure_impl(mar::Fill) = powermeasure(_fill_value(mar), _fill_axes(mar)) + +@inline _generic_productmeasure_impl(mar::Tuple{Vararg{AbstractMeasure}}) = ProductMeasure(mar) +_generic_productmeasure_impl(mar::Tuple{Vararg{Dirac}}) = Dirac(map(m -> m.x, mar)) +_generic_productmeasure_impl(mar::Tuple) = productmeasure(map(asmeasure, mar)) + +@inline _generic_productmeasure_impl(mar::NamedTuple{names,<:Tuple{Vararg{AbstractMeasure}}}) where names = ProductMeasure(mar) +_generic_productmeasure_impl(mar::NamedTuple{names,<:Tuple{Vararg{Dirac}}}) where names = Dirac(map(m -> m.x, mar)) +_generic_productmeasure_impl(mar::NamedTuple) = productmeasure(map(asmeasure, mar)) + +@inline _generic_productmeasure_impl(mar::AbstractArray{<:AbstractProductMeasure}) = ProductMeasure(mar) + +_generic_productmeasure_impl(mar::AbstractArray{<:Dirac}) = Dirac((m -> m.value).(mar)) + +# TODO: We should be able to further optimize this +function _generic_productmeasure_impl(mar::AbstractArray{T}) where {T} + if Base.issingletontype(T) + first(mar) ^ size(mar) + else + ProductMeasure(asmeasure.(mar)) + end +end -function productmeasure(mar::ReadonlyMappedArray{T,N,A,Returns{M}}) where {T,N,A,M} +@inline function _generic_productmeasure_impl(mar::ReadonlyMappedArray{T,N,A,Returns{M}}) where {T,N,A,M} return powermeasure(mar.f.value, axes(mar.data)) end -productmeasure(mar::Base.Generator) = ProductMeasure(mar) -productmeasure(mar::AbstractArray) = ProductMeasure(mar) +@inline _generic_productmeasure_impl(mar::Base.Generator) = ProductMeasure(mar) # TODO: Make this static when its length is static -@inline function productmeasure( - 
mar::AbstractArray{WeightedMeasure{StaticFloat64{W},M}}, +@inline function _generic_productmeasure_impl( + mar::AbstractArray{<:WeightedMeasure{StaticFloat64{W},M}}, ) where {W,M} return weightedmeasure(W * length(mar), productmeasure(map(basemeasure, mar))) end -productmeasure(nt::NamedTuple) = ProductMeasure(nt) -productmeasure(tup::Tuple) = ProductMeasure(tup) +# ToDo: Remove or at least refactor this (ProductMeasure shouldn't take a kernel at it's argument). -productmeasure(f, param_maps, pars) = ProductMeasure(kernel(f, param_maps), pars) +productmeasure(f, param_maps, pars) = productmeasure(kernel(f, param_maps), pars) function productmeasure(k::ParameterizedTransitionKernel, pars) productmeasure(k.suff, k.param_maps, pars) diff --git a/src/combinators/transformedmeasure.jl b/src/combinators/transformedmeasure.jl index 803b404b..17d19497 100644 --- a/src/combinators/transformedmeasure.jl +++ b/src/combinators/transformedmeasure.jl @@ -18,6 +18,8 @@ function parent(::AbstractTransformedMeasure) end export PushforwardMeasure +# ToDo: Store FunctionWithInverse instead of f and finv in PushforwardMeasure? + """ struct PushforwardMeasure{F,I,M,VC<:TransformVolCorr} <: AbstractPushforward f :: F @@ -78,7 +80,8 @@ end return logdensity_def(ν.origin, x) end -insupport(ν::PushforwardMeasure, y) = insupport(ν.origin, ν.finv(y)) +# ToDo: How to handle this better? +insupport(ν::PushforwardMeasure, y) = NoFastInsupport{typeof(ν)}() function testvalue(::Type{T}, ν::PushforwardMeasure) where {T} ν.f(testvalue(T, parent(ν))) @@ -88,10 +91,17 @@ end pushfwd(ν.f, basemeasure(parent(ν)), NoVolCorr()) end -_pushfwd_dof(::Type{MU}, ::Type, dof) where {MU} = NoDOF{MU}() -_pushfwd_dof(::Type{MU}, ::Type{<:Tuple{Any,Real}}, dof) where {MU} = dof -@inline getdof(ν::MU) where {MU<:PushforwardMeasure} = getdof(ν.origin) +const _NonBijectivePushforward = Union{PushforwardMeasure{<:Any,<:NoInverse},PushforwardMeasure{<:NoInverse,<:Any},PushforwardMeasure{<:NoInverse,<:NoInverse}} + +@inline getdof(ν::PushforwardMeasure) = _pushfwd_dof(ν) +_pushfwd_dof(ν::PushforwardMeasure) = getdof(ν.origin) +_pushfwd_dof(ν::_NonBijectivePushforward) = NoDOF{typeof(ν)}() + +@inline fast_dof(ν::PushforwardMeasure) = _pushfwd_fastdof(ν) +_pushfwd_fastdof(ν::PushforwardMeasure) = fast_dof(ν.origin) +_pushfwd_fastdof(ν::_NonBijectivePushforward) = NoDOF{typeof(ν)}() + # Bypass `checked_arg`, would require potentially costly transformation: @inline checked_arg(::PushforwardMeasure, x) = x @@ -107,8 +117,6 @@ end ############################################################################### # pushfwd -export pushfwd - """ pushfwd(f, μ, volcorr = WithVolCorr()) @@ -119,28 +127,43 @@ measure](https://en.wikipedia.org/wiki/Pushforward_measure) from `μ` the To manually specify an inverse, call `pushfwd(InverseFunctions.setinverse(f, finv), μ, volcorr)`. 
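+
+For example (illustrative):
+
+```julia
+ν = pushfwd(exp, StdNormal())   # the image of the standard normal under exp
+```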
""" -function pushfwd(f, μ, volcorr::TransformVolCorr = WithVolCorr()) +function pushfwd end +export pushfwd + +pushfwd(f) = Base.Fix1(pushfwd, f) + +pushfwd(f, μ, volcorr::TransformVolCorr = WithVolCorr()) = _generic_pushfwd_impl(f, μ, volcorr) + +function _generic_pushfwd_impl(f, μ, volcorr::TransformVolCorr = WithVolCorr()) PushforwardMeasure(f, inverse(f), μ, volcorr) end -function pushfwd(f, μ::PushforwardMeasure, volcorr::TransformVolCorr = WithVolCorr()) +function _generic_pushfwd_impl(f, μ::PushforwardMeasure, volcorr::TransformVolCorr = WithVolCorr()) _pushfwd_of_pushfwd(f, μ, μ.volcorr, volcorr) end # Either both WithVolCorr or both NoVolCorr, so we can merge them -function _pushfwd_of_pushfwd(f, μ::PushforwardMeasure, ::V, v::V) where {V} - pushfwd(fchain((μ.f, f)), μ.origin, v) +function _pushfwd_of_pushfwd(f, μ::PushforwardMeasure, ::V, volcorr::V) where {V} + pushfwd(f ∘ fchain(μ.f), μ.origin, volcorr) +end + +function _pushfwd_of_pushfwd(f, μ::PushforwardMeasure, _, volcorr) + PushforwardMeasure(f, inverse(f), μ, volcorr) end -function _pushfwd_of_pushfwd(f, μ::PushforwardMeasure, _, v) - PushforwardMeasure(f, inverse(f), μ, v) +function _generic_pushfwd_impl(f::TransportFunction{NU,MU}, μ::DensityMeasure{F,MU}, ::WithVolCorr) where {NU,MU,F} + if !(f.μ === μ.base || f.μ === μ.base) + throw(ArgumentError("pushfwd on DensityMeasure with TransportFunction of same source measure type as the density base requires base and source to be equal.")) + end + mintegrate(fchain(μ.f) ∘ inverse(f), f.ν) end + ############################################################################### # pullback """ - pullback(f, μ, volcorr = WithVolCorr()) + pullbck(f, μ, volcorr = WithVolCorr()) A _pullback_ is a dual concept to a _pushforward_. While a pushforward needs a map _from_ the support of a measure, a pullback requires a map _into_ the @@ -152,8 +175,40 @@ in terms of the inverse function; the "forward" function is not used at all. In some cases, we may be focusing on log-density (and not, for example, sampling). To manually specify an inverse, call -`pullback(InverseFunctions.setinverse(f, finv), μ, volcorr)`. +`pullbck(InverseFunctions.setinverse(f, finv), μ, volcorr)`. """ -function pullback(f, μ, volcorr::TransformVolCorr = WithVolCorr()) - pushfwd(setinverse(inverse(f), f), μ, volcorr) + +function pullbck end +export pullbck + +pullbck(f) = Base.Fix1(pullbck, f) + +pullbck(f, μ, volcorr::TransformVolCorr = WithVolCorr()) = _generic_pullbck_impl(f, μ, volcorr) + +function _generic_pullbck_impl(f, μ, volcorr::TransformVolCorr = WithVolCorr()) + PushforwardMeasure(inverse(f), f, μ, volcorr) +end + +# TODO: Duplicated method - was this supposed to be `_generic_pullbck_impl`? 
+# function _generic_pushfwd_impl(f, μ::PushforwardMeasure, volcorr::TransformVolCorr = WithVolCorr()) +# _pullbck_of_pushfwd(f, μ, μ.volcorr, volcorr) +# end + +# Either both WithVolCorr or both NoVolCorr, so we can merge them +function _pullbck_of_pushfwd(f, μ::PushforwardMeasure, ::V, volcorr::V) where {V} + pullbck(fchain(μ.finv) ∘ f, μ.origin, volcorr) end + +function _pullbck_of_pushfwd(f, μ::PushforwardMeasure, _, volcorr) + PushforwardMeasure(inverse(f), f, μ, volcorr) +end + +function _generic_pullbck_impl(f::TransportFunction{NU,MU}, μ::DensityMeasure{F,NU}, ::WithVolCorr) where {NU,MU,F} + if !(f.ν === μ.base || f.ν === μ.base) + throw(ArgumentError("pushfwd on DensityMeasure with TransportFunction of same destination measure type as the density base requires base and destination to be equal.")) + end + mintegrate(fchain(μ.f) ∘ f, f.μ) +end + + +@deprecate pullback(f, μ, volcorr::TransformVolCorr = WithVolCorr()) pullbck(f, μ, volcorr) diff --git a/src/combinators/weighted.jl b/src/combinators/weighted.jl index db239b50..124662b6 100644 --- a/src/combinators/weighted.jl +++ b/src/combinators/weighted.jl @@ -46,9 +46,6 @@ end Base.:*(m::AbstractMeasure, k::Real) = k * m -≪(::M, ::WeightedMeasure{R,M}) where {R,M} = true -≪(::WeightedMeasure{R,M}, ::M) where {R,M} = true - gentype(μ::WeightedMeasure) = gentype(μ.base) insupport(μ::WeightedMeasure, x) = insupport(μ.base, x) diff --git a/src/density-core.jl b/src/density-core.jl index c8c861ee..bcd2e51e 100644 --- a/src/density-core.jl +++ b/src/density-core.jl @@ -9,6 +9,41 @@ export densityof export density_rel export density_def + +""" + localmeasure(m::AbstractMeasure, x)::AbstractMeasure + +Return a measure that behaves like `m` in the infinitesimal neighborhood +of `x` in respect to density calculation. + +Note that the resulting measure may not be well defined outside of the +infinitesimal neighborhood of `x`. + +For most measure types simply returns `m` itself. [`mbind`](@ref), +for example, generates measures for with `localmeasure(m, x)` depends +on `x`. +""" +localmeasure(m::AbstractMeasure, x) = m +export localmeasure + + +""" + MeasureBase.transportmeasure(μ::Bind, x)::AbstractMeasure + +Return a measure that behaves like `m` in the infinitesimal neighborhood +of `x` in respect to both transport and density calculation. + +Note that the resulting measure may not be well defined outside of the +infinitesimal neighborhood of `x`. + +For most measure types simply returns `m` itself. [`mbind`](@ref), +for example, generates measures for with `transportmeasure(m, x)` depends +on `x`. +""" +transportmeasure(m::AbstractMeasure, x) = m +export localmeasure + + """ logdensityof(m::AbstractMeasure, x) @@ -34,6 +69,7 @@ To compute a log-density relative to a specific base-measure, see end _checksupport(cond, result) = ifelse(cond == true, result, oftype(result, -Inf)) +@inline _checksupport(::NoFastInsupport, result) = result import ChainRulesCore @inline function ChainRulesCore.rrule(::typeof(_checksupport), cond, result) @@ -56,10 +92,16 @@ This is "unsafe" because it does not check `insupport(m, x)`. See also `logdensityof`. 
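+
+For example (illustrative):
+
+```julia
+unsafe_logdensityof(StdExponential(), 0.3)   # ≈ -0.3, support is not checked
+```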
""" @inline function unsafe_logdensityof(μ::M, x) where {M} + μ_local = localmeasure(μ, x) + # Extra dispatch boundary to reduce number of required specializations of implementation: + return _unsafe_logdensityof_local(μ_local, x) +end + +@inline function _unsafe_logdensityof_local(μ::M, x) where {M} ℓ_0 = logdensity_def(μ, x) b_0 = μ Base.Cartesian.@nexprs 10 i -> begin # 10 is just some "big enough" number - b_{i} = basemeasure(b_{i - 1}, x) + b_{i} = basemeasure(b_{i - 1}) # The below makes the evaluated code shorter, but screws up Zygote # if b_{i} isa typeof(b_{i - 1}) @@ -72,6 +114,17 @@ See also `logdensityof`. return ℓ_10 end + +""" + logdensity_type(m::AbstractMeasure}, ::Type{T}) where T + +Compute the return type of `logdensity_of(m, ::T)`. +""" +function logdensity_type(m::M,T) where {M<:AbstractMeasure} + Core.Compiler.return_type(logdensity_def, Tuple{M, T}) +end + + """ logdensity_rel(m1, m2, x) @@ -81,20 +134,42 @@ known to be in the support of both, it can be more efficient to call `unsafe_logdensity_rel`. """ @inline function logdensity_rel(μ::M, ν::N, x::X) where {M,N,X} + inμ = insupport(μ, x) + inν = insupport(ν, x) + return _logdensity_rel_impl(μ, ν, x, inμ, inν) +end + + +@inline function _logdensity_rel_impl(μ::M, ν::N, x::X, inμ::Bool, inν::Bool) where {M,N,X} T = unstatic( promote_type( - return_type(logdensity_def, (μ, x)), - return_type(logdensity_def, (ν, x)), + logdensity_type(μ, X), + logdensity_type(ν, X), ), ) - inμ = insupport(μ, x) - inν = insupport(ν, x) + istrue(inμ) || return convert(T, ifelse(inν, -Inf, NaN)) istrue(inν) || return convert(T, Inf) return unsafe_logdensity_rel(μ, ν, x) end + +@inline function _logdensity_rel_impl(μ::M, ν::N, x::X, @nospecialize(::NoFastInsupport), @nospecialize(::NoFastInsupport)) where {M,N,X} + unsafe_logdensity_rel(μ, ν, x) +end + +@inline function _logdensity_rel_impl(μ::M, ν::N, x::X, inμ::Bool, @nospecialize(::NoFastInsupport)) where {M,N,X} + logd = unsafe_logdensity_rel(μ, ν, x) + return istrue(inμ) ? logd : logd * oftype(logd, -Inf) +end + +@inline function _logdensity_rel_impl(μ::M, ν::N, x::X, @nospecialize(::NoFastInsupport), inν::Bool) where {M,N,X} + logd = unsafe_logdensity_rel(μ, ν, x) + return istrue(inν) ? logd : logd * oftype(logd, +Inf) +end + + """ unsafe_logdensity_rel(m1, m2, x) @@ -104,6 +179,13 @@ known to be in the support of both `m1` and `m2`. See also `logdensity_rel`. """ @inline function unsafe_logdensity_rel(μ::M, ν::N, x::X) where {M,N,X} + μ_local = localmeasure(μ, x) + ν_local = localmeasure(ν, x) + # Extra dispatch boundary to reduce number of required specializations of implementation: + return _unsafe_logdensity_rel_local(μ_local, ν_local, x) +end + +@inline function _unsafe_logdensity_rel_local(μ::M, ν::N, x::X) where {M,N,X} if static_hasmethod(logdensity_def, Tuple{M,N,X}) return logdensity_def(μ, ν, x) end diff --git a/src/density.jl b/src/density.jl index 4862dcb1..543e33d6 100644 --- a/src/density.jl +++ b/src/density.jl @@ -20,8 +20,7 @@ For measures `μ` and `ν`, `Density(μ,ν)` represents the _density function_ `dμ/dν`, also called the _Radom-Nikodym derivative_: https://en.wikipedia.org/wiki/Radon%E2%80%93Nikodym_theorem#Radon%E2%80%93Nikodym_derivative -Instead of calling this directly, users should call `density_rel(μ, ν)` or -its abbreviated form, `𝒹(μ,ν)`. +Instead of calling this directly, users should call `density_rel(μ, ν)`. 
""" struct Density{M,B} <: AbstractDensity μ::M @@ -32,16 +31,6 @@ Base.:∘(::typeof(log), d::Density) = logdensity_rel(d.μ, d.base) Base.log(d::Density) = log ∘ d -export 𝒹 - -""" - 𝒹(μ, base) - -Compute the density (Radom-Nikodym derivative) of μ with respect to `base`. This -is a shorthand form for `density_rel(μ, base)`. -""" -𝒹(μ, base) = density_rel(μ, base) - density_rel(μ, base) = Density(μ, base) (f::Density)(x) = density_rel(f.μ, f.base, x) @@ -73,16 +62,6 @@ Base.:∘(::typeof(exp), d::LogDensity) = density_rel(d.μ, d.base) Base.exp(d::LogDensity) = exp ∘ d -export log𝒹 - -""" - log𝒹(μ, base) - -Compute the log-density (Radom-Nikodym derivative) of μ with respect to `base`. -This is a shorthand form for `logdensity_rel(μ, base)` -""" -log𝒹(μ, base) = logdensity_rel(μ, base) - logdensity_rel(μ, base) = LogDensity(μ, base) (f::LogDensity)(x) = logdensity_rel(f.μ, f.base, x) @@ -98,12 +77,13 @@ DensityInterface.funcdensity(d::LogDensity) = throw(MethodError(funcdensity, (d, base :: B end -A `DensityMeasure` is a measure defined by a density or log-density with respect -to some other "base" measure. +A `DensityMeasure` is a measure defined by a density or log-density with +respect to some other "base" measure. -Users should not call `DensityMeasure` directly, but should instead call `∫(f, -base)` (if `f` is a density function or `DensityInterface.IsDensity` object) or -`∫exp(f, base)` (if `f` is a log-density function). +Users should not instantiate `DensityMeasure` directly, but should instead +call `mintegral(f, base)` (if `f` is a density function or +`DensityInterface.IsDensity` object) or `mintegral_exp(f, base)` (if `f` +is a log-density function). """ struct DensityMeasure{F,B} <: AbstractMeasure f::F @@ -116,60 +96,91 @@ struct DensityMeasure{F,B} <: AbstractMeasure end @inline function insupport(d::DensityMeasure, x) - insupport(d.base, x) == true && isfinite(logdensityof(getfield(d, :f), x)) + # ToDo: should not evaluate f + insupport(d.base, x) != false && isfinite(logdensityof(getfield(d, :f), x)) end function Pretty.tile(μ::DensityMeasure{F,B}) where {F,B} - result = Pretty.literal("DensityMeasure ∫(") + result = Pretty.literal("mintegrate(") result *= Pretty.pair_layout(Pretty.tile(μ.f), Pretty.tile(μ.base); sep = ", ") result *= Pretty.literal(")") end -export ∫ +basemeasure(μ::DensityMeasure) = μ.base -""" - ∫(f, base::AbstractMeasure) +logdensity_def(μ::DensityMeasure, x) = logdensityof(μ.f, x) -Define a new measure in terms of a density `f` over some measure `base`. -""" -∫(f, base) = _densitymeasure(f, base, DensityKind(f)) +density_def(μ::DensityMeasure, x) = densityof(μ.f, x) -_densitymeasure(f, base, ::IsDensity) = DensityMeasure(f, base) -function _densitymeasure(f, base, ::HasDensity) - @error "`∫(f, base)` requires `DensityKind(f)` to be `IsDensity()` or `NoDensity()`." -end -_densitymeasure(f, base, ::NoDensity) = DensityMeasure(funcdensity(f), base) +localmeasure(μ::DensityMeasure, x) = DensityMeasure(μ.f, localmeasure(μ.base, x)) -export ∫exp +@doc raw""" + mintegrate(f, μ::AbstractMeasure)::AbstractMeasure -""" - ∫exp(f, base::AbstractMeasure) +Returns a new measure that represents the indefinite +[integral](https://en.wikipedia.org/wiki/Radon%E2%80%93Nikodym_theorem) +of `f` with respect to `μ`. + +`ν = mintegrate(f, μ)` generates a measure `ν` that has the mathematical +interpretation -Define a new measure in terms of a log-density `f` over some measure `base`. 
+math``` +\nu(A) = \int_A f(a) \, \rm{d}\mu(a) +``` """ -∫exp(f, base) = _logdensitymeasure(f, base, DensityKind(f)) +function mintegrate end +export mintegrate -function _logdensitymeasure(f, base, ::IsDensity) - @error "`∫exp(f, base)` is not valid when `DensityKind(f) == IsDensity()`. Use `∫(f, base)` instead." -end -function _logdensitymeasure(f, base, ::HasDensity) - @error "`∫exp(f, base)` is not valid when `DensityKind(f) == HasDensity()`." -end -_logdensitymeasure(f, base, ::NoDensity) = DensityMeasure(logfuncdensity(f), base) +mintegrate(f, μ::AbstractMeasure) = _mintegrate_impl(f, μ, DensityKind(f)) -basemeasure(μ::DensityMeasure) = μ.base +_mintegrate_impl(f, μ, ::IsDensity) = DensityMeasure(f, μ) +function _mintegrate_impl(f, μ, ::HasDensity) + throw( + ArgumentError( + "`mintegrate(f, mu)` requires `DensityKind(f)` to be `IsDensity()` or `NoDensity()`.", + ), + ) +end +_mintegrate_impl(f, μ, ::NoDensity) = DensityMeasure(funcdensity(f), μ) -logdensity_def(μ::DensityMeasure, x) = logdensityof(μ.f, x) +@doc raw""" + mintegrate_exp(log_f, μ::AbstractMeasure) -density_def(μ::DensityMeasure, x) = densityof(μ.f, x) +Given a function `log_f` that semantically represents the log of a function +`f`, `mintegrate` returns a new measure that represents the indefinite +[integral](https://en.wikipedia.org/wiki/Radon%E2%80%93Nikodym_theorem) +of `f` with respect to `μ`. -""" - rebase(μ, ν) +`ν = mintegrate_exp(log_f, μ)` generates a measure `ν` that has the +mathematical interpretation -Express `μ` in terms of a density over `ν`. Satisfies +math``` +\nu(A) = \int_A e^{log(f(a))} \, \rm{d}\mu(a) = \int_A f(a) \, \rm{d}\mu(a) ``` -basemeasure(rebase(μ, ν)) == ν -density(rebase(μ, ν)) == 𝒹(μ,ν) -``` + +Note that `exp(log_f(...))` is usually not run explicitly, calculations that +involve the resulting measure are typically performed in log-space, +internally. """ -rebase(μ, ν) = ∫(𝒹(μ, ν), ν) +function mintegrate_exp end +export mintegrate_exp + +function mintegrate_exp(log_f, μ::AbstractMeasure) + _mintegrate_exp_impl(log_f, μ, DensityKind(log_f)) +end + +function _mintegrate_exp_impl(log_f, μ, ::IsDensity) + throw( + ArgumentError( + "`mintegrate_exp(log_f, μ)` is not valid when `DensityKind(log_f) == IsDensity()`. 
Use `mintegrate(log_f, μ)` instead.", + ), + ) +end +function _mintegrate_exp_impl(log_f, μ, ::HasDensity) + throw( + ArgumentError( + "`mintegrate_exp(log_f, μ)` is not valid when `DensityKind(log_f) == HasDensity()`.", + ), + ) +end +_mintegrate_exp_impl(log_f, μ, ::NoDensity) = DensityMeasure(logfuncdensity(log_f), μ) diff --git a/src/domains.jl b/src/domains.jl index e03f753c..9c1f1a21 100644 --- a/src/domains.jl +++ b/src/domains.jl @@ -106,7 +106,7 @@ function tangentat( one(T) - Statistics.corm(g1, zero(T), g2, zero(T)) < tol end -function zeroset(::CodimOne)::ZeroSet end +function zeroset(::CodimOne) end ########################################################### # Simplex @@ -116,7 +116,7 @@ struct Simplex <: CodimOne end function zeroset(::Simplex) f(x::AbstractArray{T}) where {T} = sum(x) - one(T) - ∇f(x::AbstractArray{T}) where {T} = fill_with(one(T), size(x)) + ∇f(x::AbstractArray{T}) where {T} = maybestatic_fill(one(T), size(x)) ZeroSet(f, ∇f) end diff --git a/src/fixedrng.jl b/src/fixedrng.jl deleted file mode 100644 index 232b0891..00000000 --- a/src/fixedrng.jl +++ /dev/null @@ -1,19 +0,0 @@ -export FixedRNG -struct FixedRNG <: AbstractRNG end - -Base.rand(::FixedRNG) = one(Float64) / 2 -Random.randn(::FixedRNG) = zero(Float64) -Random.randexp(::FixedRNG) = one(Float64) - -Base.rand(::FixedRNG, ::Type{T}) where {T<:Real} = one(T) / 2 -Random.randn(::FixedRNG, ::Type{T}) where {T<:Real} = zero(T) -Random.randexp(::FixedRNG, ::Type{T}) where {T<:Real} = one(T) - -# We need concrete type parameters to avoid amiguity for these cases -for T in [Float16, Float32, Float64] - @eval begin - Base.rand(::FixedRNG, ::Type{$T}) = one($T) / 2 - Random.randn(::FixedRNG, ::Type{$T}) = zero($T) - Random.randexp(::FixedRNG, ::Type{$T}) = one($T) - end -end diff --git a/src/getdof.jl b/src/getdof.jl index 4496b7f2..fa60b8c0 100644 --- a/src/getdof.jl +++ b/src/getdof.jl @@ -1,11 +1,30 @@ """ - MeasureBase.NoDOF{MU} + abstract type MeasureBase.AbstractNoDOF{MU} + +Abstract supertype for [`NoDOF`](@ref) and [`NoFastDOF`](@ref). +""" +abstract type AbstractNoDOF{MU} end + +Base.:+(nodof::AbstractNoDOF) = nodof +Base.:+(::IntegerLike, nodof::AbstractNoDOF) = nodof +Base.:+(nodof::AbstractNoDOF, ::IntegerLike) = nodof +Base.:+(nodof::AbstractNoDOF, ::AbstractNoDOF) = nodof + +Base.:*(nodof::AbstractNoDOF) = nodof +Base.:*(::IntegerLike, nodof::AbstractNoDOF) = nodof +Base.:*(nodof::AbstractNoDOF, ::IntegerLike) = nodof +Base.:*(nodof::AbstractNoDOF, ::AbstractNoDOF) = nodof + + +""" + MeasureBase.NoDOF{MU} <: AbstractNoDOF{MU} Indicates that there is no way to compute degrees of freedom of a measure of type `MU` with the given information, e.g. because the DOF are not a global property of the measure. """ -struct NoDOF{MU} end +struct NoDOF{MU} <: AbstractNoDOF{MU} end + """ getdof(μ) @@ -20,24 +39,89 @@ is `n - 1`. Also see [`check_dof`](@ref). """ function getdof end +export getdof # Prevent infinite recursion: -@inline _default_getdof(::Type{MU}, ::MU) where {MU} = NoDOF{MU} +@inline _default_getdof(::Type{MU}, ::MU) where {MU} = NoDOF{MU}() @inline _default_getdof(::Type{MU}, mu_base) where {MU} = getdof(mu_base) @inline getdof(μ::MU) where {MU} = _default_getdof(MU, basemeasure(μ)) + +""" + MeasureBase.NoFastDOF{MU} <: AbstractNoDOF{MU} + +Indicates that there is no way to compute degrees of freedom of a measure +of type `MU` with the given information, e.g. because the DOF are not +a global property of the measure. 
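+
+In contrast to [`NoDOF`](@ref), this only indicates that no *fast* method is
+available; [`getdof`](@ref) may still be able to compute the DOF.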
+""" +struct NoFastDOF{MU} <: AbstractNoDOF{MU} end + + +""" + fast_dof(μ::MU) + +Returns the effective number of degrees of freedom of variates of +measure `μ`, if it can be computed efficiently, otherwise +returns [`NoFastDOF{MU}()`](@ref). + +Defaults to `getdof(μ)` and should be specialized for measures for +wich DOF can't be computed instantly. + +The effective NDOF my differ from the length of the variates. For example, +the effective NDOF for a Dirichlet distribution with variates of length `n` +is `n - 1`. + +Also see [`check_dof`](@ref). +""" +function fast_dof end +export fast_dof + +fast_dof(μ) = getdof(μ) + + +""" + MeasureBase.some_dof(μ::AbstractMeasure) + +Get the DOF at some unspecified point of measure `μ`. + +Use with caution! + +In general, use [`getdof(μ)`](@ref) instead. `some_dof` is useful +for measures are expected to have a constant DOF of their whole +space but for which there is no way to compute it (or prove that +the DOF is constant of the measurable space). +""" +function some_dof end + +function some_dof(μ) + m = asmeasure(μ) + _try_direct_dof(m, getdof(m)) +end + +_try_direct_dof(::AbstractMeasure, dof::IntegerLike) = dof +_try_direct_dof(μ::AbstractMeasure, ::AbstractNoDOF) = _try_local_dof(μ::AbstractMeasure, some_dof(_some_localmeasure(μ))) +_try_local_dof(::AbstractMeasure, dof::IntegerLike) = dof +_try_local_dof(μ::AbstractMeasure, ::AbstractNoDOF) = throw(ArgumentError("Can't determine DOF for measure of type $(nameof(typeof(μ)))")) + +_some_localmeasure(μ::AbstractMeasure) = localmeasure(μ, testvalue(μ)) + + """ MeasureBase.check_dof(ν, μ)::Nothing Check if `ν` and `μ` have the same effective number of degrees of freedom -according to [`MeasureBase.getdof`](@ref). +according to [`MeasureBase.fast_dof`](@ref). """ function check_dof end function check_dof(ν, μ) - n_ν = getdof(ν) - n_μ = getdof(μ) + n_ν = fast_dof(ν) + n_μ = fast_dof(μ) + # TODO: How to handle this better if DOF is unclear e.g. for HierarchicalMeasures? + if n_ν isa AbstractNoDOF || n_μ isa AbstractNoDOF + return true + end if n_ν != n_μ throw( ArgumentError( @@ -51,6 +135,7 @@ end _check_dof_pullback(ΔΩ) = NoTangent(), NoTangent(), NoTangent() ChainRulesCore.rrule(::typeof(check_dof), ν, μ) = check_dof(ν, μ), _check_dof_pullback + """ MeasureBase.NoArgCheck{MU,T} diff --git a/src/insupport.jl b/src/insupport.jl index 5184917d..4da0adcd 100644 --- a/src/insupport.jl +++ b/src/insupport.jl @@ -1,20 +1,33 @@ +""" + MeasureBase.NoFastInsupport{MU} + +Indicates that there is no fast way to compute if a point lies within the +support of measures of type `MU` +""" +struct NoFastInsupport{MU} end + + """ inssupport(m, x) insupport(m) -`insupport(m,x)` computes whether `x` is in the support of `m`. +`insupport(m,x)` computes whether `x` is in the support of `m` and +returns either a `Bool` or an instance of [`NoFastInsupport`](@ref). `insupport(m)` returns a function, and satisfies - -insupport(m)(x) == insupport(m, x) +`insupport(m)(x) == insupport(m, x)`` """ function insupport end + """ MeasureBase.require_insupport(μ, x)::Nothing Checks if `x` is in the support of distribution/measure `μ`, throws an `ArgumentError` if not. + +Will not throw an exception if `insupport` returns an instance of +[`NoFastInsupport`](@ref). 
""" function require_insupport end @@ -24,8 +37,11 @@ function ChainRulesCore.rrule(::typeof(require_insupport), μ, x) end function require_insupport(μ, x) - if !insupport(μ, x) - throw(ArgumentError("x is not within the support of μ")) + ins = insupport(μ, x) + if !(ins isa NoFastInsupport) + if !ins + throw(ArgumentError("x is not within the support of μ")) + end end return nothing end diff --git a/src/interface.jl b/src/interface.jl index 18080ac7..f66c9893 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -64,7 +64,7 @@ function test_interface(μ::M) where {M} # testvalue, logdensityof x = @inferred testvalue(Float64, μ) - β = @inferred basemeasure(μ, x) + β = @inferred basemeasure(μ) ℓμ = @inferred logdensityof(μ, x) ℓβ = @inferred logdensityof(β, x) diff --git a/src/mass-interface.jl b/src/mass-interface.jl index 7b0518f9..09a83b2c 100644 --- a/src/mass-interface.jl +++ b/src/mass-interface.jl @@ -22,7 +22,10 @@ for T in (:UnknownFiniteMass, :UnknownMass) @eval begin Base.:+(::$T, ::$T) = $T() Base.:*(::$T, ::$T) = $T() - Base.:^(::$T, k::Number) = isfinite(k) ? $T() : UnknownMass() + Base.:^(::$T, k::Real) = isfinite(k) ? $T() : UnknownMass() + # Disambiguation: + Base.:^(::$T, k::Integer) = isfinite(k) ? $T() : UnknownMass() + Base.:^(::$T, k::Rational) = isfinite(k) ? $T() : UnknownMass() end end @@ -65,7 +68,7 @@ finite, or we may know nothing at all about it. For these cases, it will return `UnknownFiniteMass` or `UnknownMass`, respectively. When no `massof` method exists, it defaults to `UnknownMass`. """ -massof(m::AbstractMeasure) = UnknownMass(m) +massof(::AbstractMeasure) = UnknownMass() struct NormalizedMeasure{P,M} <: AbstractMeasure parent::P @@ -104,9 +107,6 @@ isnormalized(x, p::Real = 2) = isone(norm(x, p)) isone(::AbstractUnknownMass) = false -function massof(m, s) - _massof(m, s, rootmeasure(m)) -end """ (m::AbstractMeasure)(s) @@ -116,4 +116,11 @@ in this way, users should add the corresponding `massof` method. """ (m::AbstractMeasure)(s) = massof(m, s) -massof(μ, a_b::AbstractInterval) = smf(μ, rightendpoint(a_b)) - smf(μ, leftendpoint(a_b)) +function massof(m, s) + _default_massof_impl(m, s, rootmeasure(m)) +end + +# # ToDo: Use smf if defined +#function _default_massof_impl(μ, a_b::AbstractInterval, ::LebesgueBase) +# smf(μ, rightendpoint(a_b)) - smf(μ, leftendpoint(a_b)) +#end diff --git a/src/measure_operators.jl b/src/measure_operators.jl new file mode 100644 index 00000000..90f973f0 --- /dev/null +++ b/src/measure_operators.jl @@ -0,0 +1,131 @@ +""" + module MeasureOperators + +Defines the following operators for measures: + +* `f ⋄ μ == pushfwd(f, μ)` + +* `μ ⊙ f == inverse(f) ⋄ μ` +""" +module MeasureOperators + +using MeasureBase: AbstractMeasure +using MeasureBase: pushfwd, pullbck, mbind, productmeasure +using MeasureBase: mintegrate, mintegrate_exp, density_rel, logdensity_rel +using InverseFunctions: inverse +using Reexport: @reexport + +@doc raw""" + ⋄(f, μ::AbstractMeasure) = pushfwd(f, μ) + +The `\\diamond` operator denotes a pushforward operation: `ν = f ⋄ μ` +generates a +[pushforward measure](https://en.wikipedia.org/wiki/Pushforward_measure). + +A common mathematical notation for a pushforward is ``f_*μ``, but as +there is no "subscript-star" operator in Julia, we use `⋄`. + +See [`pushfwd(f, μ)`](@ref) for details. + +Also see [`ν ⊙ f`](@ref), the pullback operator. +""" +⋄(f, μ::AbstractMeasure) = pushfwd(f, μ) +export ⋄ + +@doc raw""" + ⊙(ν::AbstractMeasure, f) = pullbck(f, ν) + +The `\\odot` operator denotes a pullback operation. 
+ +See also [`pullbck(ν, f)`](@ref) for details. Note that `pullbck` takes it's +arguments in different order, in keeping with the Julia convention of +passing functions as the first argument. A pullback is mathematically the +precomposition of a measure `μ`` with the function `f` applied to sets. so +`⊙` takes the measure as the first and the function as the second argument, +as common in mathematical notation for precomposition. + +A common mathematical notation for pullback in measure theory is +``f \circ μ``, but as `∘` is used for function composition in Julia and as +`f` semantically acts point-wise on sets, we use `⊙`. + +Also see [f ⋄ μ](@ref), the pushforward operator. +""" +⊙(ν::AbstractMeasure, f) = pullbck(f, ν) +export ⊙ + +""" + μ ▷ k = mbind(k, μ) + +The `\\triangleright` operator denotes a measure monadic bind operation. + +A common operator choice for a monadic bind operator is `>>=` (e.g. in +the Haskell programming language), but this has a different meaning in +Julia and there is no close equivalent, so we use `▷`. + +See [`mbind(k, μ)`](@ref) for details. Note that `mbind` takes its +arguments in different order, in keeping with the Julia convention of +passing functions as the first argument. `▷`, on the other hand, takes +its arguments in the order common for monadic binds in functional +programming (like the Haskell `>>=` operator) and mathematics. +""" +▷(μ::AbstractMeasure, k) = mbind(k, μ) +export ▷ + +# ToDo: Use `⨂` instead of `⊗` for better readability? +""" + ⊗(μs::AbstractMeasure...) = productmeasure(μs) + +`⊗` is an operator for building product measures. + +See [`productmeasure(μs)`](@ref) for details. +""" +⊗(μs::AbstractMeasure...) = productmeasure(μs) +export ⊗ + +""" + ∫(f, μ::AbstractMeasure) = mintegrate(f, μ) + +Denotes an indefinite integral of the function `f` with respect to the +measure `μ`. + +See [`mintegrate(f, μ)`](@ref) for details. +""" +∫(f, μ::AbstractMeasure) = mintegrate(f, μ) +export ∫ + +""" + ∫exp(f, μ::AbstractMeasure) = mintegrate_exp(f, μ) + +Generates a new measure that is the indefinite integral of `exp` of `f` +with respect to the measure `μ`. + +See [`mintegrate_exp(f, μ)`](@ref) for details. +""" +∫exp(f, μ::AbstractMeasure) = mintegrate_exp(f, μ) +export ∫exp + +""" + 𝒹(ν, μ) = density_rel(ν, μ) + +Compute the density, i.e. the +[Radom-Nikodym derivative](https://en.wikipedia.org/wiki/Radon%E2%80%93Nikodym_theorem) +of `ν`` with respect to `μ`. + +For details, see [`density_rel(ν, μ)`}(@ref). +""" +𝒹(ν, μ::AbstractMeasure) = density_rel(ν, μ) +export 𝒹 + +""" + log𝒹(ν, μ) = logdensity_rel(ν, μ) + +Compute the log-density, i.e. the logarithm of the +[Radom-Nikodym derivative](https://en.wikipedia.org/wiki/Radon%E2%80%93Nikodym_theorem) +of `ν`` with respect to `μ`. + +For details, see [`logdensity_rel(ν, μ)`}(@ref). +""" +log𝒹(ν, μ::AbstractMeasure) = logdensity_rel(ν, μ) +export log𝒹 + +end # module MeasureOperators diff --git a/src/parameterized.jl b/src/parameterized.jl index 78e43995..8b1c8c88 100644 --- a/src/parameterized.jl +++ b/src/parameterized.jl @@ -127,14 +127,3 @@ params(::Type{PM}) where {N,PM<:ParameterizedMeasure{N}} = N function paramnames(μ, constraints::NamedTuple{N}) where {N} tuple((k for k in paramnames(μ) if k ∉ N)...) 
end - -############################################################################### -# kernelfactor - -function kernelfactor(::Type{P}) where {N,P<:ParameterizedMeasure{N}} - (constructorof(P), N) -end - -function kernelfactor(::P) where {N,P<:ParameterizedMeasure{N}} - (constructorof(P), N) -end diff --git a/src/primitives/lebesgue.jl b/src/primitives/lebesgue.jl index 8c42766f..2d2ed8dd 100644 --- a/src/primitives/lebesgue.jl +++ b/src/primitives/lebesgue.jl @@ -26,12 +26,12 @@ end massof(::LebesgueBase) = static(Inf) -function _massof(m, s::Interval, ::LebesgueBase) +function _default_massof_impl(m, s::AbstractInterval, ::LebesgueBase) mass = massof(m) nu = mass * StdUniform() f = transport_to(nu, m) - a = f(minimum(s)) - b = f(maximum(s)) + a = f(leftendpoint(s)) + b = f(rightendpoint(s)) return mass * abs(b - a) end diff --git a/src/proxies.jl b/src/proxies.jl index 95aed270..ffbf286f 100644 --- a/src/proxies.jl +++ b/src/proxies.jl @@ -15,15 +15,24 @@ macro useproxy(M) M = esc(M) quote @inline $MeasureBase.logdensity_def(μ::$M, x) = logdensity_def(proxy(μ), x) + @inline $MeasureBase.unsafe_logdensityof(μ::$M, x) = unsafe_logdensityof(proxy(μ), x) @inline $MeasureBase.basemeasure(μ::$M) = basemeasure(proxy(μ)) - @inline $MeasureBase.basemeasure_depth(μ::$M) = basemeasure_depth(proxy(μ)) + @inline $MeasureBase.rootmeasure(μ::$M) = rootmeasure(proxy(μ)) + @inline $MeasureBase.insupport(μ::$M) = insupport(proxy(μ)) + + @inline $MeasureBase.getdof(μ::$M) = getdof(proxy(μ)) + @inline $MeasureBase.fast_dof(μ::$M) = fast_dof(proxy(μ)) + @inline $MeasureBase.transport_origin(μ::$M) = transport_origin(proxy(μ)) @inline $MeasureBase.to_origin(μ::$M, y) = to_origin(proxy(μ), y) @inline $MeasureBase.from_origin(μ::$M, x) = from_origin(proxy(μ), x) + @inline $MeasureBase.localmeasure(μ::$M, x) = localmeasure(proxy(μ), x) + @inline $MeasureBase.transportmeasure(μ::$M, x) = transportmeasure(proxy(μ), x) + @inline $MeasureBase.massof(μ::$M) = massof(proxy(μ)) @inline $MeasureBase.massof(μ::$M, s) = massof(proxy(μ), s) diff --git a/src/standard/stdexponential.jl b/src/standard/stdexponential.jl index a02c5765..c1aaa88c 100644 --- a/src/standard/stdexponential.jl +++ b/src/standard/stdexponential.jl @@ -1,8 +1,17 @@ +""" + StdExponential <: StdMeasure + +Represents the standard (rate of one) +[exponential](https://en.wikipedia.org/wiki/Exponential_distribution) probability measure. + +See [`StdMeasure`](@ref) for the semantics of standard measures in the +context of MeasureBase. +""" struct StdExponential <: StdMeasure end export StdExponential -insupport(d::StdExponential, x) = x ≥ zero(x) +insupport(::StdExponential, x) = x ≥ zero(x) @inline logdensity_def(::StdExponential, x) = -x @inline basemeasure(::StdExponential) = LebesgueBase() diff --git a/src/standard/stdlogistic.jl b/src/standard/stdlogistic.jl index 0d502ec6..705c153c 100644 --- a/src/standard/stdlogistic.jl +++ b/src/standard/stdlogistic.jl @@ -1,8 +1,16 @@ -struct StdLogistic <: StdMeasure end +""" + StdLogistic <: StdMeasure + +Represents the standard (centered, scale of one) +[logistic](https://en.wikipedia.org/wiki/Logistic_distribution) probability measure. +See [`StdMeasure`](@ref) for the semantics of standard measures in the +context of MeasureBase. 
+""" +struct StdLogistic <: StdMeasure end export StdLogistic -@inline insupport(d::StdLogistic, x) = true +@inline insupport(::StdLogistic, x) = true @inline logdensity_def(::StdLogistic, x) = (u = -abs(x); u - 2 * log1pexp(u)) @inline basemeasure(::StdLogistic) = LebesgueBase() diff --git a/src/standard/stdmeasure.jl b/src/standard/stdmeasure.jl index 833f280e..05843ede 100644 --- a/src/standard/stdmeasure.jl +++ b/src/standard/stdmeasure.jl @@ -1,123 +1,39 @@ -abstract type StdMeasure <: AbstractMeasure end - -StdMeasure(::typeof(rand)) = StdUniform() -StdMeasure(::typeof(randexp)) = StdExponential() -StdMeasure(::typeof(randn)) = StdNormal() - -@inline check_dof(::StdMeasure, ::StdMeasure) = nothing - -@inline transport_def(::MU, μ::MU, x) where {MU<:StdMeasure} = x - -function transport_def(ν::StdMeasure, μ::PowerMeasure{<:StdMeasure}, x) - return transport_def(ν, μ.parent, only(x)) -end - -function transport_def(ν::PowerMeasure{<:StdMeasure}, μ::StdMeasure, x) - return fill_with(transport_def(ν.parent, μ, only(x)), map(length, ν.axes)) -end - -function transport_def( - ν::PowerMeasure{<:StdMeasure,<:NTuple{1,Base.OneTo}}, - μ::PowerMeasure{<:StdMeasure,<:NTuple{1,Base.OneTo}}, - x, -) - return transport_to(ν.parent, μ.parent).(x) -end - -function transport_def( - ν::PowerMeasure{<:StdMeasure,<:NTuple{N,Base.OneTo}}, - μ::PowerMeasure{<:StdMeasure,<:NTuple{M,Base.OneTo}}, - x, -) where {N,M} - return reshape(transport_to(ν.parent, μ.parent).(x), map(length, ν.axes)...) -end +""" + abstract type MeasureBase.StdMeasure -# Implement transport_to(NU::Type{<:StdMeasure}, μ) and transport_to(ν, MU::Type{<:StdMeasure}): +Abstract supertype for standard measures. -_std_measure(::Type{M}, ::StaticInteger{1}) where {M<:StdMeasure} = M() -_std_measure(::Type{M}, dof::IntegerLike) where {M<:StdMeasure} = M()^dof -_std_measure_for(::Type{M}, μ::Any) where {M<:StdMeasure} = _std_measure(M, getdof(μ)) +Standard measures must be singleton types that represent common fundamental +measures such as [`StdUniform`](@ref), [`StdExponential`](@ref), +[`StdNormal`](@ref) and [`StdLogistic`](@ref). -function transport_to(::Type{NU}, μ) where {NU<:StdMeasure} - transport_to(_std_measure_for(NU, μ), μ) -end +A standard measure ([`StdUniform`](@ref), [`StdExponential`](@ref) and +[`StdNormal`](@ref)) is defined for every common Julia random number +generation function: -function transport_to(ν, ::Type{MU}) where {MU<:StdMeasure} - transport_to(ν, _std_measure_for(MU, ν)) -end +``` +StdMeasure(rand) == StdUniform() +StdMeasure(randexp) == StdExponential() +StdMeasure(randn) == StdNormal() +``` -# Transform between standard measures and Dirac: +[`StdLogistic`](@ref) has no associated random number generation function. 
-@inline transport_def(ν::Dirac, ::PowerMeasure{<:StdMeasure}, ::Any) = ν.x - -@inline function transport_def(ν::PowerMeasure{<:StdMeasure}, ::Dirac, ::Any) - Zeros{Bool}(map(_ -> 0, ν.axes)) -end - -# Helpers for product transforms and similar: - -struct _TransportToStd{NU<:StdMeasure} <: Function end -_TransportToStd{NU}(μ, x) where {NU} = transport_to(NU()^getdof(μ), μ)(x) - -struct _TransportFromStd{MU<:StdMeasure} <: Function end -_TransportFromStd{MU}(ν, x) where {MU} = transport_to(ν, MU()^getdof(ν))(x) - -function _tuple_transport_def( - ν::PowerMeasure{NU}, - μs::Tuple, - xs::Tuple, -) where {NU<:StdMeasure} - reshape(vcat(map(_TransportToStd{NU}, μs, xs)...), ν.axes) -end - -function transport_def( - ν::PowerMeasure{NU}, - μ::ProductMeasure{<:Tuple}, - x, -) where {NU<:StdMeasure} - _tuple_transport_def(ν, marginals(μ), x) -end +All standard measures must be normalized, i.e. [`massof`](@ref) always +returns one. +""" +abstract type StdMeasure <: AbstractMeasure end -function transport_def( - ν::PowerMeasure{NU}, - μ::ProductMeasure{<:NamedTuple{names}}, - x, -) where {NU<:StdMeasure,names} - _tuple_transport_def(ν, values(marginals(μ)), values(x)) -end +@inline massof(::StdMeasure) = static(true) +@inline getdof(::StdMeasure) = static(1) -@inline _offset_cumsum(s, x, y, rest...) = (s, _offset_cumsum(s + x, y, rest...)...) -@inline _offset_cumsum(s, x) = (s,) -@inline _offset_cumsum(s) = () +StdMeasure(::typeof(rand)) = StdUniform() +StdMeasure(::typeof(randexp)) = StdExponential() +StdMeasure(::typeof(randn)) = StdNormal() -function _stdvar_viewranges(μs::Tuple, startidx::IntegerLike) - N = map(getdof, μs) - offs = _offset_cumsum(startidx, N...) - map((o, n) -> o:o+n-1, offs, N) -end +@inline check_dof(::StdMeasure, ::StdMeasure) = nothing -function _tuple_transport_def( - νs::Tuple, - μ::PowerMeasure{MU}, - x::AbstractArray{<:Real}, -) where {MU<:StdMeasure} - vrs = _stdvar_viewranges(νs, firstindex(x)) - xs = map(r -> view(x, r), vrs) - map(_TransportFromStd{MU}, νs, xs) -end -function transport_def( - ν::ProductMeasure{<:Tuple}, - μ::PowerMeasure{MU}, - x, -) where {MU<:StdMeasure} - _tuple_transport_def(marginals(ν), μ, x) -end +# Transport between two equal standard measures: -function transport_def( - ν::ProductMeasure{<:NamedTuple{names}}, - μ::PowerMeasure{MU}, - x, -) where {MU<:StdMeasure,names} - NamedTuple{names}(_tuple_transport_def(values(marginals(ν)), μ, x)) -end +@inline transport_def(::MU, μ::MU, x) where {MU<:StdMeasure} = x diff --git a/src/standard/stdnormal.jl b/src/standard/stdnormal.jl index dc9cac74..b083606b 100644 --- a/src/standard/stdnormal.jl +++ b/src/standard/stdnormal.jl @@ -1,11 +1,19 @@ using SpecialFunctions: erfc, erfcinv using IrrationalConstants: invsqrt2 -struct StdNormal <: StdMeasure end +""" + StdNormal <: StdMeasure + +Represents the standard (mean of zero, variance of one) +[normal](https://en.wikipedia.org/wiki/Normal_distribution) probability measure. +See [`StdMeasure`](@ref) for the semantics of standard measures in the +context of MeasureBase. 
+""" +struct StdNormal <: StdMeasure end export StdNormal -@inline insupport(d::StdNormal, x) = true +@inline insupport(::StdNormal, x) = true @inline logdensity_def(::StdNormal, x) = -x^2 / 2 @inline basemeasure(::StdNormal) = WeightedMeasure(static(-0.5 * log2π), LebesgueBase()) diff --git a/src/standard/stduniform.jl b/src/standard/stduniform.jl index 8817561e..d443ce2e 100644 --- a/src/standard/stduniform.jl +++ b/src/standard/stduniform.jl @@ -1,8 +1,18 @@ -struct StdUniform <: StdMeasure end +""" + StdUniform <: StdMeasure + +Represents the standard +[uniform](https://en.wikipedia.org/wiki/Continuous_uniform_distribution) +probability measure (from zero to one). It is the +same as the Lebesgue measure restricted to the unit interval. +See [`StdMeasure`](@ref) for the semantics of standard measures in the +context of MeasureBase. +""" +struct StdUniform <: StdMeasure end export StdUniform -insupport(d::StdUniform, x) = zero(x) ≤ x ≤ one(x) +insupport(::StdUniform, x) = zero(x) ≤ x ≤ one(x) @inline logdensity_def(::StdUniform, x) = zero(x) @inline basemeasure(::StdUniform) = LebesgueBase() diff --git a/src/static.jl b/src/static.jl index b723d043..3f77cf3a 100644 --- a/src/static.jl +++ b/src/static.jl @@ -5,6 +5,31 @@ Equivalent to `Union{Integer,Static.StaticInteger}`. """ const IntegerLike = Union{Integer,Static.StaticInteger} + +""" + const UnitRangeFromOne + +Alias for unit ranges that start at one. +""" +const UnitRangeFromOne = Union{Base.OneTo, Static.OptionallyStaticUnitRange, StaticArrays.SOneTo} + + +""" + const StaticOneTo{N} + +A static unit range from one to N. +""" +const StaticOneTo{N} = Union{Static.OptionallyStaticUnitRange{StaticInt{1},StaticInt{N}}, StaticArrays.SOneTo{N}} + + +""" + const StaticUnitRange + +A static unit range. +""" +const StaticUnitRange = Union{Static.OptionallyStaticUnitRange{<:StaticInt,<:StaticInt}, StaticArrays.SOneTo} + + """ MeasureBase.one_to(n::IntegerLike) @@ -18,22 +43,32 @@ on the type of `n`. _dynamic(x::Number) = dynamic(x) _dynamic(::Static.SOneTo{N}) where {N} = Base.OneTo(N) -_dynamic(r::AbstractUnitRange) = minimum(r):maximum(r) +_dynamic(r::Base.OneTo) = Base.OneTo(dynamic(r.stop)) + +function _dynamic(r::AbstractUnitRange) + if isempty(r) + 1:0 + else + minimum(r):maximum(r) + end +end """ - MeasureBase.fill_with(x, sz::NTuple{N,<:IntegerLike}) where N + MeasureBase.maybestatic_fill(x, sz::NTuple{N,<:IntegerLike}) where N Creates an array of size `sz` filled with `x`. Returns an instance of `FillArrays.Fill`. """ -function fill_with end +function maybestatic_fill end -@inline function fill_with(x::T, sz::Tuple{Vararg{IntegerLike,N}}) where {T,N} - fill_with(x, map(one_to, sz)) +@inline maybestatic_fill(x::T, ::Tuple{}) where T = FillArrays.Fill(x) + +@inline function maybestatic_fill(x::T, sz::Tuple{Vararg{IntegerLike,N}}) where {T,N} + maybestatic_fill(x, map(one_to, sz)) end -@inline function fill_with(x::T, axs::Tuple{Vararg{AbstractUnitRange,N}}) where {T,N} +@inline function maybestatic_fill(x::T, axs::Tuple{Vararg{AbstractUnitRange,N}}) where {T,N} # While `FillArrays.Fill` (mostly?) works with axes that are static unit # ranges, some operations that automatic differentiation requires do fail # on such instances of `Fill` (e.g. `reshape` from dynamic to static size). 
@@ -42,6 +77,42 @@ end FillArrays.Fill(x, dyn_axs) end +@inline function maybestatic_fill(x::T, axs::Tuple{Vararg{StaticOneTo}}) where T + fill(x, staticarray_type(T, map(maybestatic_length, axs))) +end + +@inline function maybestatic_fill(x::T, sz::Tuple{Vararg{StaticInteger}}) where T + fill(x, staticarray_type(T, sz)) +end + + +""" + staticarray_type(T, sz::Tuple{Vararg{StaticInteger}}) + +Returns the type of a static array with element type `T` and size `sz`. +""" +function staticarray_type end + +@inline @generated function staticarray_type(::Type{T}, sz::Tuple{Vararg{StaticInteger,N}}) where {T,N} + szs = map(p -> p.parameters[1], sz.parameters) + len = prod(szs) + :(SArray{Tuple{$szs...},T,$N,$len}) +end + + +""" + MeasureBase.maybestatic_reshape(A, sz) + +Reshapes array `A` to sizes `sz`. + +If `A` is a static array and `sz` is static, the result is a static array. +""" +function maybestatic_reshape end + +maybestatic_reshape(A, sz) = reshape(A, sz) +maybestatic_reshape(A::StaticArray, sz::Tuple{Vararg{StaticInteger}}) = staticarray_type(eltype(A), sz)(Tuple(A)) + + """ MeasureBase.maybestatic_length(x)::IntegerLike @@ -49,13 +120,53 @@ Returns the length of `x` as a dynamic or static integer. """ maybestatic_length(x) = length(x) maybestatic_length(x::AbstractUnitRange) = length(x) -function maybestatic_length(::Static.OptionallyStaticUnitRange{<:StaticInteger{A},<:StaticInteger{B}}) where {A,B} +maybestatic_length(::Tuple{Vararg{Any,N}}) where N = static(N) +maybestatic_length(nt::NamedTuple) = maybestatic_length(values(nt)) +maybestatic_length(x::StaticArray) = maybestatic_length(maybestatic_eachindex(x)) +maybestatic_length(::StaticArrays.SOneTo{N}) where {N} = static(N) +function maybestatic_length( + ::Static.OptionallyStaticUnitRange{<:StaticInteger{A},<:StaticInteger{B}}, +) where {A,B} StaticInt{B - A + 1}() end + """ MeasureBase.maybestatic_size(x)::Tuple{Vararg{IntegerLike}} Returns the size of `x` as a tuple of dynamic or static integers. """ -maybestatic_size(x) = size(x) +maybestatic_size(x) = map(maybestatic_length, axes(x)) + + +""" + MeasureBase.maybestatic_eachindex(x) + +Returns the the index range of `x` as a dynamic or static integer range +""" +maybestatic_eachindex(x::AbstractArray) = _conv_static_eachindex(eachindex(x)) +maybestatic_eachindex(::Tuple{Vararg{Any,N}}) where N = static(1):static(N) +maybestatic_eachindex(nt::NamedTuple) = maybestatic_eachindex(values(nt)) + +_conv_static_eachindex(idxs) = idxs +_conv_static_eachindex(::Static.SOneTo{N}) where {N} = static(1):static(N) + + +""" + MeasureBase.maybestatic_first(A) + +Returns the first element of `A` as a dynamic or static value. +""" +maybestatic_first(A::AbstractArray) = first(A) +maybestatic_first(::StaticArrays.SOneTo{N}) where N = static(1) +maybestatic_first(::Static.OptionallyStaticUnitRange{<:Static.StaticInteger{from},<:Static.StaticInteger}) where from = static(from) + + +""" + MeasureBase.maybestatic_last(A) + +Returns the last element of `A` as a dynamic or static value. +""" +maybestatic_last(A::AbstractArray) = last(A) +maybestatic_last(::StaticArrays.SOneTo{N}) where N = static(N) +maybestatic_last(::Static.OptionallyStaticUnitRange{<:Static.StaticInteger,<:Static.StaticInteger{until}}) where until = static(until) diff --git a/src/transport.jl b/src/transport.jl index ce8ce1fd..93506b6b 100644 --- a/src/transport.jl +++ b/src/transport.jl @@ -8,6 +8,9 @@ See [`MeasureBase.transport_origin`](@ref). 
""" struct NoTransportOrigin{NU} end +Base.:^(origin::NoTransportOrigin, ::IntegerLike) = origin + + """ MeasureBase.transport_origin(ν) @@ -76,22 +79,12 @@ and/or * `MeasureBase.from_origin(μ::MyMeasure, x) = y` * `MeasureBase.to_origin(μ::MyMeasure, y) = x` -and ensure `MeasureBase.getdof(μ::MyMeasure)` is defined correctly. - -A standard measure type like `StdUniform`, `StdExponential` or -`StdLogistic` may also be used as the source or target of the transform: - -```julia -f_to_uniform(StdUniform, μ) -f_to_uniform(ν, StdUniform) -``` - -Depending on [`getdof(μ)`](@ref) (resp. `ν`), an instance of the standard -distribution itself or a power of it (e.g. `StdUniform()` or -`StdUniform()^dof`) will be chosen as the transformation partner. +and ensure `MeasureBase.fast_dof(μ::MyMeasure)` is defined correctly. """ function transport_to end +@inline transport_to(ν, μ) = TransportFunction(asmeasure(ν), asmeasure(μ)) + """ transport_to(ν, μ, x) @@ -99,6 +92,7 @@ Transport `x` from the measure `μ` to the measure `ν` """ transport_to(ν, μ, x) = transport_to(ν, μ)(x) + """ transport_def(ν, μ, x) @@ -150,6 +144,13 @@ end μ, x, ) where {n_ν,n_μ} + if n_ν == 10 + return :(throw(ArgumentError("Transport to measure of type $(nameof(typeof(ν))) not supported, origin stack too deep."))) + end + if n_μ == 10 + return :(throw(ArgumentError("Transport from measure of type $(nameof(typeof(μ))) not supported, origin stack too deep."))) + end + prog = quote μ0 = μ x0 = x @@ -186,8 +187,8 @@ end return prog end -@inline _transport_intermediate(ν, μ) = _transport_intermediate(getdof(ν), getdof(μ)) -@inline _transport_intermediate(::Integer, n_μ::Integer) = StdUniform()^n_μ +@inline _transport_intermediate(ν, μ) = _transport_intermediate(fast_dof(ν), fast_dof(μ)) +@inline _transport_intermediate(::IntegerLike, n_μ::IntegerLike) = StdUniform()^n_μ @inline _transport_intermediate(::StaticInteger{1}, ::StaticInteger{1}) = StdUniform() _call_transport_def(ν, μ, x) = transport_def(ν, μ, x) @@ -230,8 +231,6 @@ struct TransportFunction{NU,MU} <: Function end end -@inline transport_to(ν, μ) = TransportFunction(ν, μ) - function Base.:(==)(a::TransportFunction, b::TransportFunction) return a.ν == b.ν && a.μ == b.μ end diff --git a/src/utils.jl b/src/utils.jl index 4a0c79a6..1d51be7d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -11,16 +11,14 @@ showparams(io::IO, nt::NamedTuple) = print(io, nt) export testvalue -@inline testvalue(μ) = rand(FixedRNG(), μ) +@inline testvalue(μ) = rand(ConstantRNG(), μ) -@inline testvalue(::Type{T}, μ) where {T} = rand(FixedRNG(), T, μ) +@inline testvalue(::Type{T}, μ) where {T} = rand(ConstantRNG(), T, μ) testvalue(::Type{T}) where {T} = zero(T) export rootmeasure -basemeasure(μ, x) = basemeasure(μ) - """ rootmeasure(μ::AbstractMeasure) diff --git a/test/Project.toml b/test/Project.toml index f80fdd98..3f04208a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -10,8 +10,10 @@ IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogarithmicNumbers = "aa2f6b4e-9042-5d33-9679-40d3a6b85899" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +OneTwoMany = "762dc654-8631-413a-a342-372a7419ad9d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git 
a/test/combinators/bind.jl b/test/combinators/bind.jl new file mode 100644 index 00000000..8ad7bc65 --- /dev/null +++ b/test/combinators/bind.jl @@ -0,0 +1,11 @@ +using Test + +using MeasureBase +using MeasureBase: AbstractMeasure +using MeasureBase: StdExponential, StdLogistic, StdUniform +using MeasureBase: pushfwd, pullbck, mbind, localmeasure +using MeasureBase: mbind, mintegrate, mintegrate_exp, density_rel, logdensity_rel + +@testset "bind.jl" begin + +end diff --git a/test/combinators/combined.jl b/test/combinators/combined.jl new file mode 100644 index 00000000..7111beb8 --- /dev/null +++ b/test/combinators/combined.jl @@ -0,0 +1,11 @@ +using Test + +using MeasureBase +using MeasureBase: AbstractMeasure +using MeasureBase: StdExponential, StdLogistic, StdUniform +using MeasureBase: pushfwd, pullbck, mbind, productmeasure, jointmeasure +using MeasureBase: mbind, mintegrate, mintegrate_exp, density_rel, logdensity_rel + +@testset "combined.jl" begin + +end diff --git a/test/measure_operators.jl b/test/measure_operators.jl new file mode 100644 index 00000000..a3adaa8f --- /dev/null +++ b/test/measure_operators.jl @@ -0,0 +1,24 @@ +using Test + +using MeasureBase: AbstractMeasure +using MeasureBase: StdExponential, StdLogistic, StdUniform +using MeasureBase: pushfwd, pullbck, mbind, productmeasure +using MeasureBase: mintegrate, mintegrate_exp, density_rel, logdensity_rel +using MeasureBase.MeasureOperators: ⋄, ⊙, ▷, ⊗, ∫, ∫exp, 𝒹, log𝒹 + +@testset "MeasureOperators" begin + μ = StdExponential() + ν = StdUniform() + k(σ) = pushfwd(x -> σ * x, StdNormal()) + μs = (StdExponential(), StdLogistic(), StdUniform()) + f = sqrt + + @test @inferred(f ⋄ μ) == pushfwd(f, μ) + @test @inferred(ν ⊙ f) == pullbck(f, ν) + @test @inferred(μ ▷ k) == mbind(k, μ) + @test @inferred(⊗(μs...)) == productmeasure(μs) + @test @inferred(∫(f, μ)) == mintegrate(f, μ) + @test @inferred(∫exp(f, μ)) == mintegrate_exp(f, μ) + @test @inferred(𝒹(ν, μ)) == density_rel(ν, μ) + @test @inferred(log𝒹(ν, μ)) == logdensity_rel(ν, μ) +end diff --git a/test/runtests.jl b/test/runtests.jl index f9263b6d..e2423b81 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -19,5 +19,9 @@ include("smf.jl") include("combinators/weighted.jl") include("combinators/transformedmeasure.jl") +include("combinators/combined.jl") +include("combinators/bind.jl") + +include("measure_operators.jl") include("test_docs.jl") diff --git a/test/static.jl b/test/static.jl index a6c50db2..83e4f930 100644 --- a/test/static.jl +++ b/test/static.jl @@ -11,18 +11,21 @@ import FillArrays @test static(2) isa MeasureBase.IntegerLike @test true isa MeasureBase.IntegerLike @test static(true) isa MeasureBase.IntegerLike - + @test @inferred(MeasureBase.one_to(7)) isa Base.OneTo @test @inferred(MeasureBase.one_to(7)) == 1:7 @test @inferred(MeasureBase.one_to(static(7))) isa Static.SOneTo @test @inferred(MeasureBase.one_to(static(7))) == static(1):static(7) - @test @inferred(MeasureBase.fill_with(4.2, (7,))) == FillArrays.Fill(4.2, 7) - @test @inferred(MeasureBase.fill_with(4.2, (static(7),))) == FillArrays.Fill(4.2, 7) - @test @inferred(MeasureBase.fill_with(4.2, (3, static(7)))) == FillArrays.Fill(4.2, 3, 7) - @test @inferred(MeasureBase.fill_with(4.2, (3:7,))) == FillArrays.Fill(4.2, (3:7,)) - @test @inferred(MeasureBase.fill_with(4.2, (static(3):static(7),))) == FillArrays.Fill(4.2, (3:7,)) - @test @inferred(MeasureBase.fill_with(4.2, (3:7, static(2):static(5)))) == FillArrays.Fill(4.2, (3:7, 2:5)) + @test @inferred(MeasureBase.maybestatic_fill(4.2, (7,))) == 
FillArrays.Fill(4.2, 7) + @test @inferred(MeasureBase.maybestatic_fill(4.2, (static(7),))) == FillArrays.Fill(4.2, 7) + @test @inferred(MeasureBase.maybestatic_fill(4.2, (3, static(7)))) == + FillArrays.Fill(4.2, 3, 7) + @test @inferred(MeasureBase.maybestatic_fill(4.2, (3:7,))) == FillArrays.Fill(4.2, (3:7,)) + @test @inferred(MeasureBase.maybestatic_fill(4.2, (static(3):static(7),))) == + FillArrays.Fill(4.2, (3:7,)) + @test @inferred(MeasureBase.maybestatic_fill(4.2, (3:7, static(2):static(5)))) == + FillArrays.Fill(4.2, (3:7, 2:5)) @test MeasureBase.maybestatic_length(MeasureBase.one_to(7)) isa Int @test MeasureBase.maybestatic_length(MeasureBase.one_to(7)) == 7