Skip to content

Unify {untyped,typed}_{vector_,}varinfo constructor functions #879

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Apr 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,25 @@

**Breaking changes**

### VarInfo constructor
### VarInfo constructors

`VarInfo(vi::VarInfo, values)` has been removed. You can replace this directly with `unflatten(vi, values)` instead.

The `metadata` argument to `VarInfo([rng, ]model[, sampler, context, metadata])` has been removed.
If you were not using this argument (most likely), then there is no change needed.
If you were using the `metadata` argument to specify a blank `VarNamedVector`, then you should replace calls to `VarInfo` with `DynamicPPL.typed_vector_varinfo` instead (see 'Other changes' below).

The `UntypedVarInfo` constructor and type is no longer exported.
If you needed to construct one, you should now use `DynamicPPL.untyped_varinfo` instead.

The `TypedVarInfo` constructor and type is no longer exported.
The _type_ has been replaced with `DynamicPPL.NTVarInfo`.
The _constructor_ has been replaced with `DynamicPPL.typed_varinfo`.

Note that the exact kind of VarInfo returned by `VarInfo(rng, model, ...)` is an implementation detail.
Previously, it was guaranteed that this would always be a VarInfo whose metadata was a `NamedTuple` containing `Metadata` structs.
Going forward, this is no longer the case, and you should only assume that the returned object obeys the `AbstractVarInfo` interface.

### VarName prefixing behaviour

The way in which VarNames in submodels are prefixed has been changed.
Expand Down Expand Up @@ -53,6 +68,20 @@ outer() | (a.x=1.0,)
If you are sampling from a model with submodels, this doesn't affect the way you interact with the `MCMCChains.Chains` object, because VarNames are converted into Symbols when stored in the chain.
(This behaviour will likely be changed in the future, in that Chains should be indexable by VarNames and not just Symbols, but that has not been implemented yet.)

**Other changes**

While these are technically breaking, they are only internal changes and do not affect the public API.
The following four functions have been added and/or reworked to make it easier to construct VarInfos with different types of metadata:

1. `DynamicPPL.untyped_varinfo([rng, ]model[, sampler, context])`
2. `DynamicPPL.typed_varinfo([rng, ]model[, sampler, context])`
3. `DynamicPPL.untyped_vector_varinfo([rng, ]model[, sampler, context])`
4. `DynamicPPL.typed_vector_varinfo([rng, ]model[, sampler, context])`

The reason for this change is that there were several flavours of VarInfo.
Some, like `typed_varinfo`, were easy to construct because we had convenience methods for them; however, the others were more difficult.
This change makes it easier to access different VarInfo types, and also makes it more explicit which one you are constructing.

## 0.35.5

Several internal methods have been removed:
Expand Down
10 changes: 4 additions & 6 deletions benchmarks/src/DynamicPPLBenchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ end

Create a benchmark suite for `model` using the selected varinfo type and AD backend.
Available varinfo choices:
• `:untyped` → uses `VarInfo()`
• `:typed` → uses `VarInfo(model)`
• `:untyped` → uses `DynamicPPL.untyped_varinfo(model)`
• `:typed` → uses `DynamicPPL.typed_varinfo(model)`
• `:simple_namedtuple` → uses `SimpleVarInfo{Float64}(model())`
• `:simple_dict` → builds a `SimpleVarInfo{Float64}` from a Dict (pre-populated with the model’s outputs)

Expand All @@ -67,11 +67,9 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::
suite = BenchmarkGroup()

vi = if varinfo_choice == :untyped
vi = VarInfo()
model(rng, vi)
vi
DynamicPPL.untyped_varinfo(rng, model)
elseif varinfo_choice == :typed
VarInfo(rng, model)
DynamicPPL.typed_varinfo(rng, model)
elseif varinfo_choice == :simple_namedtuple
SimpleVarInfo{Float64}(model(rng))
elseif varinfo_choice == :simple_dict
Expand Down
13 changes: 6 additions & 7 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -291,18 +291,17 @@ AbstractVarInfo

But exactly how a [`AbstractVarInfo`](@ref) stores this information can vary.

For constructing the "default" typed and untyped varinfo types used in DynamicPPL (see [the section on varinfo design](@ref "Design of `VarInfo`") for more on this), we have the following two methods:
#### `VarInfo`

```@docs
DynamicPPL.untyped_varinfo
DynamicPPL.typed_varinfo
VarInfo
```

#### `VarInfo`

```@docs
VarInfo
TypedVarInfo
DynamicPPL.untyped_varinfo
DynamicPPL.typed_varinfo
DynamicPPL.untyped_vector_varinfo
DynamicPPL.typed_vector_varinfo
```

One main characteristic of [`VarInfo`](@ref) is that samples are transformed to unconstrained Euclidean space and stored in a linearized form, as described in the [main Turing documentation](https://turinglang.org/docs/developers/transforms/dynamicppl/).
Expand Down
4 changes: 2 additions & 2 deletions docs/src/internals/varinfo.md
Original file line number Diff line number Diff line change
Expand Up @@ -227,13 +227,13 @@ Continuing from the example from the previous section, we can use a `VarInfo` wi

```@example varinfo-design
# Type-unstable
varinfo_untyped_vnv = DynamicPPL.VectorVarInfo(varinfo_untyped)
varinfo_untyped_vnv = DynamicPPL.untyped_vector_varinfo(varinfo_untyped)
varinfo_untyped_vnv[@varname(x)], varinfo_untyped_vnv[@varname(y)]
```

```@example varinfo-design
# Type-stable
varinfo_typed_vnv = DynamicPPL.VectorVarInfo(varinfo_typed)
varinfo_typed_vnv = DynamicPPL.typed_vector_varinfo(varinfo_typed)
varinfo_typed_vnv[@varname(x)], varinfo_typed_vnv[@varname(y)]
```

Expand Down
2 changes: 0 additions & 2 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ import Base:
# VarInfo
export AbstractVarInfo,
VarInfo,
UntypedVarInfo,
TypedVarInfo,
SimpleVarInfo,
push!!,
empty!!,
Expand Down
8 changes: 4 additions & 4 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,11 +247,11 @@ julia> values_as(SimpleVarInfo(data), Vector)
2.0
```

`TypedVarInfo`:
`VarInfo` with `NamedTuple` of `Metadata`:

```jldoctest
julia> # Just use an example model to construct the `VarInfo` because we're lazy.
vi = VarInfo(DynamicPPL.TestUtils.demo_assume_dot_observe());
vi = DynamicPPL.typed_varinfo(DynamicPPL.TestUtils.demo_assume_dot_observe());

julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0;

Expand All @@ -273,11 +273,11 @@ julia> values_as(vi, Vector)
2.0
```

`UntypedVarInfo`:
`VarInfo` with `Metadata`:

```jldoctest
julia> # Just use an example model to construct the `VarInfo` because we're lazy.
vi = VarInfo(); DynamicPPL.TestUtils.demo_assume_dot_observe()(vi);
vi = DynamicPPL.untyped_varinfo(DynamicPPL.TestUtils.demo_assume_dot_observe());

julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0;

Expand Down
2 changes: 1 addition & 1 deletion src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ function default_varinfo(
context::AbstractContext,
)
init_sampler = initialsampler(sampler)
return VarInfo(rng, model, init_sampler, context)
return typed_varinfo(rng, model, init_sampler, context)
end

function AbstractMCMC.sample(
Expand Down
4 changes: 2 additions & 2 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Currently only implemented for `NT<:NamedTuple` and `NT<:AbstractDict`.
$(FIELDS)

# Notes
The major differences between this and `TypedVarInfo` are:
The major differences between this and `NTVarInfo` are:
1. `SimpleVarInfo` does not require linearization.
2. `SimpleVarInfo` can use more efficient bijectors.
3. `SimpleVarInfo` is only type-stable if `NT<:NamedTuple` and either
Expand Down Expand Up @@ -244,7 +244,7 @@ function SimpleVarInfo{T}(
end

# Constructor from `VarInfo`.
function SimpleVarInfo(vi::TypedVarInfo, ::Type{D}=NamedTuple; kwargs...) where {D}
function SimpleVarInfo(vi::NTVarInfo, (::Type{D})=NamedTuple; kwargs...) where {D}
return SimpleVarInfo{eltype(getlogp(vi))}(vi, D; kwargs...)
end
function SimpleVarInfo{T}(
Expand Down
2 changes: 1 addition & 1 deletion src/test_utils/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod
@test (DynamicPPL.evaluate!!(model, varinfo_untyped, SamplingContext(context)); true)
@test (DynamicPPL.evaluate!!(model, varinfo_untyped, context); true)
# Typed varinfo.
varinfo_typed = DynamicPPL.TypedVarInfo(varinfo_untyped)
varinfo_typed = DynamicPPL.typed_varinfo(varinfo_untyped)
@test (DynamicPPL.evaluate!!(model, varinfo_typed, SamplingContext(context)); true)
@test (DynamicPPL.evaluate!!(model, varinfo_typed, context); true)
end
10 changes: 4 additions & 6 deletions src/test_utils/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,10 @@ function setup_varinfos(
model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false
)
# VarInfo
vi_untyped_metadata = VarInfo(DynamicPPL.Metadata())
vi_untyped_vnv = VarInfo(DynamicPPL.VarNamedVector())
model(vi_untyped_metadata)
model(vi_untyped_vnv)
vi_typed_metadata = DynamicPPL.TypedVarInfo(vi_untyped_metadata)
vi_typed_vnv = DynamicPPL.TypedVarInfo(vi_untyped_vnv)
vi_untyped_metadata = DynamicPPL.untyped_varinfo(model)
vi_untyped_vnv = DynamicPPL.untyped_vector_varinfo(model)
vi_typed_metadata = DynamicPPL.typed_varinfo(model)
vi_typed_vnv = DynamicPPL.typed_vector_varinfo(model)

# SimpleVarInfo
svi_typed = SimpleVarInfo(example_values)
Expand Down
Loading
Loading