diff --git a/HISTORY.md b/HISTORY.md
index 9a70e8d1f..68650f9d1 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -1,5 +1,23 @@
 # DynamicPPL Changelog
 
+## 0.37.0
+
+**Breaking changes**
+
+### Accumulators
+
+This release overhauls how `VarInfo` objects track quantities such as the log joint probability. The new approach uses what we call accumulators: objects carried by the `VarInfo` that may update their state at each `tilde_assume!!` and `tilde_observe!!` call based on the value of the variable in question. Accumulators replace both the variables that were previously hard-coded in the `VarInfo` object (`logp` and `num_produce`) and some contexts. This brings a number of breaking changes (a short example of the new accessors follows the list):
+
+  - `PriorContext` and `LikelihoodContext` no longer exist. By default, a `VarInfo` tracks both the log prior and the log likelihood separately, and they can be accessed with `getlogprior` and `getloglikelihood`. If you want to execute a model while only accumulating one of the two (to save clock cycles), you can do so by creating a `VarInfo` that only has one accumulator in it, e.g. `varinfo = setaccs!!(varinfo, (LogPriorAccumulator(),))`.
+  - `MiniBatchContext` no longer exists. It can be replaced by creating and using a custom accumulator that takes the place of the default `LogLikelihoodAccumulator`. We may introduce such an accumulator in DynamicPPL in the future, but for now you'll need to write it yourself.
+  - `tilde_observe` and `observe` have been removed. `tilde_observe!!` still exists, and any contexts that need to customise observation behaviour should modify it instead. We may further rework the call stack under `tilde_observe!!` in the near future.
+  - `tilde_assume` no longer returns the log density of the current assumption as its second return value. We may further rework the `tilde_assume!!` call stack as well.
+  - For literal observation statements like `0.0 ~ Normal()` we used to call `tilde_observe!!` without the `vn` argument. That method no longer exists; instead, `tilde_observe!!` is now called with `vn` set to `nothing`.
+  - `set/reset/increment_num_produce!` have become `set/reset/increment_num_produce!!` (note the second exclamation mark). They are no longer guaranteed to modify the `VarInfo` in place, and one should always use the return value.
+  - `@addlogprob!` now _always_ adds to the log likelihood. Previously it added to the log probability that the execution context specified, e.g. the log prior when using `PriorContext`.
+  - `getlogp` now returns a `NamedTuple` with keys `logprior` and `loglikelihood`. If you want the log joint probability, which is what `getlogp` used to return, use `getlogjoint`.
+  - Correspondingly, `setlogp!!` and `acclogp!!` should now be called with a `NamedTuple` with keys `logprior` and `loglikelihood`. The single-scalar method of `acclogp!!` has been deprecated and falls back on `accloglikelihood!!`; the single-scalar method of `setlogp!!` has been removed. Corresponding setter and accumulator functions exist for the log prior as well.
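+
+For example, the new accessors can be used as follows (the model below is purely illustrative):
+
+```julia
+using DynamicPPL, Distributions
+
+@model function demo()
+    x ~ Normal()
+    0.0 ~ Normal(x)
+end
+
+vi = VarInfo(demo())
+getlogprior(vi)       # log prior only
+getloglikelihood(vi)  # log likelihood only
+getlogjoint(vi)       # what `getlogp` used to return
+getlogp(vi)           # (; logprior=..., loglikelihood=...)
+```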
+
 ## 0.36.0
 
 **Breaking changes**
diff --git a/Project.toml b/Project.toml
index 01e2cb612..25c6acd24 100644
--- a/Project.toml
+++ b/Project.toml
@@ -21,6 +21,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
+Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
@@ -68,6 +69,7 @@ MCMCChains = "6"
 MacroTools = "0.5.6"
 Mooncake = "0.4.95"
 OrderedCollections = "1"
+Printf = "1.10"
 Random = "1.6"
 Requires = "1"
 Statistics = "1"
diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl
index 89b65d2de..9661dd505 100644
--- a/benchmarks/benchmarks.jl
+++ b/benchmarks/benchmarks.jl
@@ -100,4 +100,5 @@ PrettyTables.pretty_table(
     header=header,
     tf=PrettyTables.tf_markdown,
     formatters=ft_printf("%.1f", [6, 7]),
+    crop=:none,  # Always print the whole table, even if it doesn't fit in the terminal.
 )
diff --git a/docs/src/api.md b/docs/src/api.md
index 08522e2ce..8e5c64886 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -160,7 +160,7 @@ returned(::Model)
 
 ## Utilities
 
-It is possible to manually increase (or decrease) the accumulated log density from within a model function.
+It is possible to manually increase (or decrease) the accumulated log likelihood or prior from within a model function.
 
 ```@docs
 @addlogprob!
@@ -328,9 +328,9 @@ The following functions were used for sequential Monte Carlo methods.
 
 ```@docs
 get_num_produce
-set_num_produce!
-increment_num_produce!
-reset_num_produce!
+set_num_produce!!
+increment_num_produce!!
+reset_num_produce!!
 setorder!
 set_retained_vns_del!
 ```
@@ -345,6 +345,22 @@ Base.empty!
 SimpleVarInfo
 ```
 
+### Accumulators
+
+The subtypes of [`AbstractVarInfo`](@ref) store the cumulative log prior and log likelihood, and sometimes other variables that change during model execution, in objects called accumulators.
+
+```@docs
+AbstractAccumulator
+```
+
+DynamicPPL provides the following default accumulators.
+
+```@docs
+LogPriorAccumulator
+LogLikelihoodAccumulator
+NumProduceAccumulator
+```
+
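+A minimal sketch of working with accumulators (the model below is illustrative; note that `setaccs!!` is not exported):
+
+```julia
+using DynamicPPL, Distributions
+
+@model function demo()
+    x ~ Normal()
+end
+
+vi = VarInfo(demo())  # runs the model once with the default accumulators
+getlogprior(vi)       # log prior of the sampled `x`
+getloglikelihood(vi)  # 0.0, since the model has no observations
+
+# Keep only the log-prior accumulator to skip likelihood accumulation:
+vi = DynamicPPL.setaccs!!(vi, (LogPriorAccumulator(),))
+```
+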
 ### Common API
 
 #### Accumulation of log-probabilities
@@ -353,6 +369,13 @@ SimpleVarInfo
 getlogp
 setlogp!!
 acclogp!!
+getlogjoint
+getlogprior
+setlogprior!!
+acclogprior!!
+getloglikelihood
+setloglikelihood!!
+accloglikelihood!!
 resetlogp!!
 ```
 
@@ -427,9 +450,6 @@ Contexts are subtypes of `AbstractPPL.AbstractContext`.
 ```@docs
 SamplingContext
 DefaultContext
-LikelihoodContext
-PriorContext
-MiniBatchContext
 PrefixContext
 ConditionContext
 ```
@@ -476,7 +496,3 @@ DynamicPPL.Experimental.is_suitable_varinfo
 ```@docs
 tilde_assume
 ```
-
-```@docs
-tilde_observe
-```
diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl
index 7fcbd6a7c..70f0f0182 100644
--- a/ext/DynamicPPLMCMCChainsExt.jl
+++ b/ext/DynamicPPLMCMCChainsExt.jl
@@ -48,10 +48,10 @@ end
 Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample
 in `chain`, and return the resulting `Chains`.
 
-The `model` passed to `predict` is often different from the one used to generate `chain`. 
-Typically, the model from which `chain` originated treats certain variables as observed (i.e., 
-data points), while the model you pass to `predict` may mark these same variables as missing 
-or unobserved. Calling `predict` then leverages the previously inferred parameter values to 
+The `model` passed to `predict` is often different from the one used to generate `chain`.
+Typically, the model from which `chain` originated treats certain variables as observed (i.e.,
+data points), while the model you pass to `predict` may mark these same variables as missing
+or unobserved. Calling `predict` then leverages the previously inferred parameter values to
 simulate what new, unobserved data might look like, given your posterior beliefs.
 
 For each parameter configuration in `chain`:
@@ -59,7 +59,7 @@ For each parameter configuration in `chain`:
 2. Any variables not included in `chain` are sampled from their prior distributions.
 
 If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by
-the samples in `chain`. This is useful when you want to sample only new variables from the posterior 
+the samples in `chain`. This is useful when you want to sample only new variables from the posterior
 predictive distribution.
 
 # Examples
@@ -124,7 +124,7 @@ function DynamicPPL.predict(
             map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)),
         )
 
-        return (varname_and_values=varname_vals, logp=DynamicPPL.getlogp(varinfo))
+        return (varname_and_values=varname_vals, logp=DynamicPPL.getlogjoint(varinfo))
     end
 
     chain_result = reduce(
diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl
index c1c613d08..7527c8be2 100644
--- a/src/DynamicPPL.jl
+++ b/src/DynamicPPL.jl
@@ -6,6 +6,7 @@ using Bijectors
 using Compat
 using Distributions
 using OrderedCollections: OrderedCollections, OrderedDict
+using Printf: Printf
 
 using AbstractMCMC: AbstractMCMC
 using ADTypes: ADTypes
@@ -46,17 +47,28 @@ import Base:
 export AbstractVarInfo,
     VarInfo,
     SimpleVarInfo,
+    AbstractAccumulator,
+    LogLikelihoodAccumulator,
+    LogPriorAccumulator,
+    NumProduceAccumulator,
     push!!,
     empty!!,
     subset,
     getlogp,
+    getlogjoint,
+    getlogprior,
+    getloglikelihood,
     setlogp!!,
+    setlogprior!!,
+    setloglikelihood!!,
     acclogp!!,
+    acclogprior!!,
+    accloglikelihood!!,
     resetlogp!!,
     get_num_produce,
-    set_num_produce!,
-    reset_num_produce!,
-    increment_num_produce!,
+    set_num_produce!!,
+    reset_num_produce!!,
+    increment_num_produce!!,
     set_retained_vns_del!,
     is_flagged,
     set_flag!,
@@ -92,15 +104,10 @@ export AbstractVarInfo,
     # Contexts
     SamplingContext,
     DefaultContext,
-    LikelihoodContext,
-    PriorContext,
-    MiniBatchContext,
     PrefixContext,
     ConditionContext,
     assume,
-    observe,
     tilde_assume,
-    tilde_observe,
     # Pseudo distributions
     NamedDist,
     NoDist,
@@ -146,6 +153,9 @@ macro prob_str(str)
     ))
 end
 
+# TODO(mhauru) We should write down the list of methods that any subtype of AbstractVarInfo
+# has to implement. Not sure what the full list is for parameters values, but for
+# accumulators we only need `getaccs` and `setaccs!!`.
 """
     AbstractVarInfo
 
@@ -166,6 +176,8 @@ include("varname.jl")
 include("distribution_wrappers.jl")
 include("contexts.jl")
 include("varnamedvector.jl")
+include("accumulators.jl")
+include("default_accumulators.jl")
 include("abstract_varinfo.jl")
 include("threadsafe.jl")
 include("varinfo.jl")
diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl
index f11b8a3ec..4917a4892 100644
--- a/src/abstract_varinfo.jl
+++ b/src/abstract_varinfo.jl
@@ -90,45 +90,289 @@ Return the `AbstractTransformation` related to `vi`.
 function transformation end
 
 # Accumulation of log-probabilities.
+"""
+    getlogjoint(vi::AbstractVarInfo)
+
+Return the log of the joint probability of the observed data and parameters in `vi`.
+
+See also: [`getlogprior`](@ref), [`getloglikelihood`](@ref).
+"""
+getlogjoint(vi::AbstractVarInfo) = getlogprior(vi) + getloglikelihood(vi)
+
 """
     getlogp(vi::AbstractVarInfo)
 
-Return the log of the joint probability of the observed data and parameters sampled in
-`vi`.
+Return a NamedTuple of the log prior and log likelihood probabilities.
+
+The keys are called `logprior` and `loglikelihood`. If either one is not present in `vi` an
+error will be thrown.
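+
+For example, given some `vi` that has both accumulators (a sketch):
+
+```julia
+lp = getlogp(vi)
+lp.logprior + lp.loglikelihood == getlogjoint(vi)  # true
+```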
+"""
+function getlogp(vi::AbstractVarInfo)
+    return (; logprior=getlogprior(vi), loglikelihood=getloglikelihood(vi))
+end
+
+"""
+    setaccs!!(vi::AbstractVarInfo, accs::AccumulatorTuple)
+    setaccs!!(vi::AbstractVarInfo, accs::NTuple{N,AbstractAccumulator} where {N})
+
+Update the `AccumulatorTuple` of `vi` to `accs`, mutating if it makes sense.
+
+`setaccs!!(vi::AbstractVarInfo, accs::AccumulatorTuple)` should be implemented by each subtype
+of `AbstractVarInfo`.
+"""
+function setaccs!!(vi::AbstractVarInfo, accs::NTuple{N,AbstractAccumulator}) where {N}
+    return setaccs!!(vi, AccumulatorTuple(accs))
+end
+
+"""
+    getaccs(vi::AbstractVarInfo)
+
+Return the `AccumulatorTuple` of `vi`.
+
+This should be implemented by each subtype of `AbstractVarInfo`.
+"""
+function getaccs end
+
+"""
+    hasacc(vi::AbstractVarInfo, ::Val{accname}) where {accname}
+
+Return a boolean for whether `vi` has an accumulator with name `accname`.
+"""
+hasacc(vi::AbstractVarInfo, accname::Val) = haskey(getaccs(vi), accname)
+function hasacc(vi::AbstractVarInfo, accname::Symbol)
+    return error(
+        """
+        The method hasacc(vi::AbstractVarInfo, accname::Symbol) does not exist. For type
+        stability reasons use hasacc(vi::AbstractVarInfo, Val(accname)) instead.
+        """
+    )
+end
+
+"""
+    acckeys(vi::AbstractVarInfo)
+
+Return the names of the accumulators in `vi`.
+"""
+acckeys(vi::AbstractVarInfo) = keys(getaccs(vi))
+
+"""
+    getlogprior(vi::AbstractVarInfo)
+
+Return the log of the prior probability of the parameters in `vi`.
+
+See also: [`getlogjoint`](@ref), [`getloglikelihood`](@ref), [`setlogprior!!`](@ref).
+"""
+getlogprior(vi::AbstractVarInfo) = getacc(vi, Val(:LogPrior)).logp
+
+"""
+    getloglikelihood(vi::AbstractVarInfo)
+
+Return the log of the likelihood probability of the observed data in `vi`.
+
+See also: [`getlogjoint`](@ref), [`getlogprior`](@ref), [`setloglikelihood!!`](@ref).
+"""
+getloglikelihood(vi::AbstractVarInfo) = getacc(vi, Val(:LogLikelihood)).logp
+
+"""
+    setacc!!(vi::AbstractVarInfo, acc::AbstractAccumulator)
+
+Add `acc` to the `AccumulatorTuple` of `vi`, mutating if it makes sense.
+
+If an accumulator with the same [`accumulator_name`](@ref) already exists, it will be
+replaced.
+
+See also: [`getaccs`](@ref).
+"""
+function setacc!!(vi::AbstractVarInfo, acc::AbstractAccumulator)
+    return setaccs!!(vi, setacc!!(getaccs(vi), acc))
+end
+
+"""
+    setlogprior!!(vi::AbstractVarInfo, logp)
+
+Set the log of the prior probability of the parameters sampled in `vi` to `logp`.
+
+See also: [`setloglikelihood!!`](@ref), [`setlogp!!`](@ref), [`getlogprior`](@ref).
+"""
+setlogprior!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogPriorAccumulator(logp))
+
+"""
+    setloglikelihood!!(vi::AbstractVarInfo, logp)
+
+Set the log of the likelihood probability of the observed data sampled in `vi` to `logp`.
+
+See also: [`setlogprior!!`](@ref), [`setlogp!!`](@ref), [`getloglikelihood`](@ref).
+"""
+setloglikelihood!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogLikelihoodAccumulator(logp))
+
+"""
+    setlogp!!(vi::AbstractVarInfo, logp::NamedTuple)
+
+Set both the log prior and the log likelihood probabilities in `vi`.
+
+`logp` should have fields `logprior` and `loglikelihood` and no other fields.
+
+See also: [`setlogprior!!`](@ref), [`setloglikelihood!!`](@ref), [`getlogp`](@ref).
+"""
+function setlogp!!(vi::AbstractVarInfo, logp::NamedTuple{names}) where {names}
+    if !(names == (:logprior, :loglikelihood) || names == (:loglikelihood, :logprior))
+        error("logp must have the fields logprior and loglikelihood and no other fields.")
+    end
+    vi = setlogprior!!(vi, logp.logprior)
+    vi = setloglikelihood!!(vi, logp.loglikelihood)
+    return vi
+end
+
+function setlogp!!(vi::AbstractVarInfo, logp::Number)
+    return error("""
+                 `setlogp!!(vi::AbstractVarInfo, logp::Number)` is no longer supported. Use
+                 `setloglikelihood!!` and/or `setlogprior!!` instead.
+                 """)
+end
+
+"""
+    getacc(vi::AbstractVarInfo, ::Val{accname})
+
+Return the `AbstractAccumulator` of `vi` with name `accname`.
+"""
+function getacc(vi::AbstractVarInfo, accname::Val)
+    return getacc(getaccs(vi), accname)
+end
+function getacc(vi::AbstractVarInfo, accname::Symbol)
+    return error(
+        """
+        The method getacc(vi::AbstractVarInfo, accname::Symbol) does not exist. For type
+        stability reasons use getacc(vi::AbstractVarInfo, Val(accname)) instead.
+        """
+    )
+end
+
+"""
+    accumulate_assume!!(vi::AbstractVarInfo, val, logjac, vn, right)
+
+Update all the accumulators of `vi` by calling `accumulate_assume!!` on them.
+"""
+function accumulate_assume!!(vi::AbstractVarInfo, val, logjac, vn, right)
+    return map_accumulators!!(acc -> accumulate_assume!!(acc, val, logjac, vn, right), vi)
+end
+
+"""
+    accumulate_observe!!(vi::AbstractVarInfo, right, left, vn)
+
+Update all the accumulators of `vi` by calling `accumulate_observe!!` on them.
 """
-function getlogp end
+function accumulate_observe!!(vi::AbstractVarInfo, right, left, vn)
+    return map_accumulators!!(acc -> accumulate_observe!!(acc, right, left, vn), vi)
+end
 
 """
-    setlogp!!(vi::AbstractVarInfo, logp)
+    map_accumulators!!(func::Function, vi::AbstractVarInfo)
 
-Set the log of the joint probability of the observed data and parameters sampled in
-`vi` to `logp`, mutating if it makes sense.
+Update all accumulators of `vi` by calling `func` on them and replacing them with the return
+values.
 """
-function setlogp!! end
+function map_accumulators!!(func::Function, vi::AbstractVarInfo)
+    return setaccs!!(vi, map(func, getaccs(vi)))
+end
 
 """
-    acclogp!!([context::AbstractContext, ]vi::AbstractVarInfo, logp)
+    map_accumulator!!(func::Function, vi::AbstractVarInfo, ::Val{accname}) where {accname}
 
-Add `logp` to the value of the log of the joint probability of the observed data and
-parameters sampled in `vi`, mutating if it makes sense.
+Update the accumulator `accname` of `vi` by calling `func` on it and replacing it with the
+return value.
 """
-function acclogp!!(context::AbstractContext, vi::AbstractVarInfo, logp)
-    return acclogp!!(NodeTrait(context), context, vi, logp)
+function map_accumulator!!(func::Function, vi::AbstractVarInfo, accname::Val)
+    return setaccs!!(vi, map_accumulator(func, getaccs(vi), accname))
+end
+
+function map_accumulator!!(func::Function, vi::AbstractVarInfo, accname::Symbol)
+    return error(
+        """
+        The method map_accumulator!!(func::Function, vi::AbstractVarInfo, accname::Symbol)
+        does not exist. For type stability reasons use
+        map_accumulator!!(func::Function, vi::AbstractVarInfo, ::Val{accname}) instead.
+        """
+    )
 end
-function acclogp!!(::IsLeaf, context::AbstractContext, vi::AbstractVarInfo, logp)
-    return acclogp!!(vi, logp)
+
+"""
+    acclogprior!!(vi::AbstractVarInfo, logp)
+
+Add `logp` to the value of the log of the prior probability in `vi`.
+
+See also: [`accloglikelihood!!`](@ref), [`acclogp!!`](@ref), [`getlogprior`](@ref), [`setlogprior!!`](@ref).
+"""
+function acclogprior!!(vi::AbstractVarInfo, logp)
+    return map_accumulator!!(acc -> acc + LogPriorAccumulator(logp), vi, Val(:LogPrior))
 end
-function acclogp!!(::IsParent, context::AbstractContext, vi::AbstractVarInfo, logp)
-    return acclogp!!(childcontext(context), vi, logp)
+
+"""
+    accloglikelihood!!(vi::AbstractVarInfo, logp)
+
+Add `logp` to the value of the log of the likelihood in `vi`.
+
+See also: [`acclogprior!!`](@ref), [`acclogp!!`](@ref), [`getloglikelihood`](@ref), [`setloglikelihood!!`](@ref).
+"""
+function accloglikelihood!!(vi::AbstractVarInfo, logp)
+    return map_accumulator!!(
+        acc -> acc + LogLikelihoodAccumulator(logp), vi, Val(:LogLikelihood)
+    )
+end
+
+"""
+    acclogp!!(vi::AbstractVarInfo, logp::NamedTuple; ignore_missing_accumulator::Bool=false)
+
+Add to the log prior and/or the log likelihood in `vi`.
+
+`logp` should have fields `logprior` and/or `loglikelihood`, and no other fields.
+
+By default, if the necessary accumulators are not present in `vi`, an error is thrown. If
+`ignore_missing_accumulator` is set to `true`, such missing accumulators are silently ignored instead.
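+
+For example (values are illustrative):
+
+```julia
+vi = acclogp!!(vi, (; logprior=-0.5, loglikelihood=-1.2))
+vi = acclogp!!(vi, (; loglikelihood=-0.3))  # accumulating only one of the two is also allowed
+```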
+"""
+function acclogp!!(
+    vi::AbstractVarInfo, logp::NamedTuple{names}; ignore_missing_accumulator=false
+) where {names}
+    if !(
+        names == (:logprior, :loglikelihood) ||
+        names == (:loglikelihood, :logprior) ||
+        names == (:logprior,) ||
+        names == (:loglikelihood,)
+    )
+        error("logp must have fields logprior and/or loglikelihood and no other fields.")
+    end
+    if haskey(logp, :logprior) &&
+        (!ignore_missing_accumulator || hasacc(vi, Val(:LogPrior)))
+        vi = acclogprior!!(vi, logp.logprior)
+    end
+    if haskey(logp, :loglikelihood) &&
+        (!ignore_missing_accumulator || hasacc(vi, Val(:LogLikelihood)))
+        vi = accloglikelihood!!(vi, logp.loglikelihood)
+    end
+    return vi
+end
+
+function acclogp!!(vi::AbstractVarInfo, logp::Number)
+    Base.depwarn(
+        "`acclogp!!(vi::AbstractVarInfo, logp::Number)` is deprecated. Use `accloglikelihood!!(vi, logp)` instead.",
+        :acclogp,
+    )
+    return accloglikelihood!!(vi, logp)
 end
 
 """
     resetlogp!!(vi::AbstractVarInfo)
 
-Reset the value of the log of the joint probability of the observed data and parameters
-sampled in `vi` to 0, mutating if it makes sense.
+Reset the values of the log probabilities (prior and likelihood) in `vi` to zero.
 """
-resetlogp!!(vi::AbstractVarInfo) = setlogp!!(vi, zero(getlogp(vi)))
+function resetlogp!!(vi::AbstractVarInfo)
+    if hasacc(vi, Val(:LogPrior))
+        vi = map_accumulator!!(zero, vi, Val(:LogPrior))
+    end
+    if hasacc(vi, Val(:LogLikelihood))
+        vi = map_accumulator!!(zero, vi, Val(:LogLikelihood))
+    end
+    return vi
+end
 
 # Variables and their realizations.
 @doc """
@@ -566,8 +810,8 @@ function link!!(
     x = vi[:]
     y, logjac = with_logabsdet_jacobian(b, x)
 
-    lp_new = getlogp(vi) - logjac
-    vi_new = setlogp!!(unflatten(vi, y), lp_new)
+    lp_new = getlogprior(vi) - logjac
+    vi_new = setlogprior!!(unflatten(vi, y), lp_new)
     return settrans!!(vi_new, t)
 end
 
@@ -578,8 +822,8 @@ function invlink!!(
     y = vi[:]
     x, logjac = with_logabsdet_jacobian(b, y)
 
-    lp_new = getlogp(vi) + logjac
-    vi_new = setlogp!!(unflatten(vi, x), lp_new)
+    lp_new = getlogprior(vi) + logjac
+    vi_new = setlogprior!!(unflatten(vi, x), lp_new)
     return settrans!!(vi_new, NoTransformation())
 end
 
@@ -723,9 +967,34 @@ function invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist, y)
     return x, logpdf(dist, x) + logjac
 end
 
-# Legacy code that is currently overloaded for the sake of simplicity.
-# TODO: Remove when possible.
-increment_num_produce!(::AbstractVarInfo) = nothing
+"""
+    get_num_produce(vi::AbstractVarInfo)
+
+Return the `num_produce` of `vi`.
+"""
+get_num_produce(vi::AbstractVarInfo) = getacc(vi, Val(:NumProduce)).num
+
+"""
+    set_num_produce!!(vi::AbstractVarInfo, n::Int)
+
+Set the `num_produce` field of `vi` to `n`.
+"""
+set_num_produce!!(vi::AbstractVarInfo, n::Int) = setacc!!(vi, NumProduceAccumulator(n))
+
+"""
+    increment_num_produce!!(vi::AbstractVarInfo)
+
+Add 1 to `num_produce` in `vi`.
+"""
+increment_num_produce!!(vi::AbstractVarInfo) =
+    map_accumulator!!(increment, vi, Val(:NumProduce))
+
+"""
+    reset_num_produce!!(vi::AbstractVarInfo)
+
+Reset the value of `num_produce` in `vi` to 0.
+"""
+reset_num_produce!!(vi::AbstractVarInfo) = map_accumulator!!(zero, vi, Val(:NumProduce))
 
 """
     from_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist])
diff --git a/src/accumulators.jl b/src/accumulators.jl
new file mode 100644
index 000000000..10a988ae5
--- /dev/null
+++ b/src/accumulators.jl
@@ -0,0 +1,189 @@
+"""
+    AbstractAccumulator
+
+An abstract type for accumulators.
+
+An accumulator is an object that may change its value at every `tilde_assume!!` or
+`tilde_observe!!` call based on the random variable in question. The obvious examples of
+accumulators are the log prior and log likelihood. Other examples might be a variable that
+counts the number of observations in a trace, or a list of the names of random variables
+seen so far.
+
+An accumulator type `T <: AbstractAccumulator` must implement the following methods:
+- `accumulator_name(acc::T)` or `accumulator_name(::Type{T})`
+- `accumulate_observe!!(acc::T, right, left, vn)`
+- `accumulate_assume!!(acc::T, val, logjac, vn, right)`
+
+To be able to work with multi-threading, it should also implement:
+- `split(acc::T)`
+- `combine(acc::T, acc2::T)`
+
+See the documentation for each of these functions for more details.
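+
+As an illustration, a minimal sketch of a hypothetical accumulator that counts the number of
+`assume` statements could look like this:
+
+```julia
+struct AssumeCountAccumulator <: AbstractAccumulator
+    count::Int
+end
+
+accumulator_name(::Type{AssumeCountAccumulator}) = :AssumeCount
+accumulate_assume!!(acc::AssumeCountAccumulator, val, logjac, vn, right) =
+    AssumeCountAccumulator(acc.count + 1)
+accumulate_observe!!(acc::AssumeCountAccumulator, right, left, vn) = acc
+split(::AssumeCountAccumulator) = AssumeCountAccumulator(0)
+combine(acc1::AssumeCountAccumulator, acc2::AssumeCountAccumulator) =
+    AssumeCountAccumulator(acc1.count + acc2.count)
+```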
+"""
+abstract type AbstractAccumulator end
+
+"""
+    accumulator_name(acc::AbstractAccumulator)
+
+Return a Symbol which can be used as a name for `acc`.
+
+The name has to be unique in the sense that a `VarInfo` can only have one accumulator for
+each name. The most typical case, and the default implementation, is that the name only
+depends on the type of `acc`, not on its value.
+"""
+accumulator_name(acc::AbstractAccumulator) = accumulator_name(typeof(acc))
+
+"""
+    accumulate_observe!!(acc::AbstractAccumulator, right, left, vn)
+
+Update `acc` in a `tilde_observe!!` call. Returns the updated `acc`.
+
+`vn` is the name of the variable being observed, `left` is the value of the variable, and
+`right` is the distribution on the RHS of the tilde statement. `vn` is `nothing` in the case
+of literal observations like `0.0 ~ Normal()`.
+
+`accumulate_observe!!` may mutate `acc`, but not any of the other arguments.
+
+See also: [`accumulate_assume!!`](@ref)
+"""
+function accumulate_observe!! end
+
+"""
+    accumulate_assume!!(acc::AbstractAccumulator, val, logjac, vn, right)
+
+Update `acc` in a `tilde_assume!!` call. Returns the updated `acc`.
+
+`vn` is the name of the variable being assumed, `val` is the value of the variable, and
+`right` is the distribution on the RHS of the tilde statement. `logjac` is the log
+determinant of the Jacobian of the transformation applied to convert the value of `vn`,
+as it was given (e.g. by a sampler operating in linked space), to `val`.
+
+`accumulate_assume!!` may mutate `acc`, but not any of the other arguments.
+
+See also: [`accumulate_observe!!`](@ref)
+"""
+function accumulate_assume!! end
+
+"""
+    split(acc::AbstractAccumulator)
+
+Return a new accumulator like `acc` but empty.
+
+The precise meaning of "empty" is that the returned value should be such that
+`combine(acc, split(acc))` is equal to `acc`. This is used in the context of multi-threading,
+where different threads may accumulate independently and the results are then combined.
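+
+For example, using the default `LogLikelihoodAccumulator` (a sketch):
+
+```julia
+acc = LogLikelihoodAccumulator(-1.2)
+combine(acc, split(acc)) == acc  # true: `split` returns a zeroed accumulator
+```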
+
+See also: [`combine`](@ref)
+"""
+function split end
+
+"""
+    combine(acc::AbstractAccumulator, acc2::AbstractAccumulator)
+
+Combine two accumulators of the same type. Returns a new accumulator.
+
+See also: [`split`](@ref)
+"""
+function combine end
+
+# TODO(mhauru) The existence of this function makes me sad. See comment in unflatten in
+# src/varinfo.jl.
+"""
+    convert_eltype(::Type{T}, acc::AbstractAccumulator)
+
+Convert `acc` to use element type `T`.
+
+What "element type" means depends on the type of `acc`. By default this function does
+nothing. Accumulator types that need to hold differentiable values, such as dual numbers
+used by various AD backends, should implement a method for this function.
+"""
+convert_eltype(::Type, acc::AbstractAccumulator) = acc
+
+"""
+    AccumulatorTuple{N,T<:NamedTuple}
+
+A collection of accumulators, stored as a `NamedTuple` of length `N`.
+
+This is defined as a separate type to be able to dispatch on it cleanly and without method
+ambiguities or conflicts with other `NamedTuple` types. We also use this type to enforce the
+constraint that the name in the tuple for each accumulator `acc` must be
+`accumulator_name(acc)`, and these names must be unique.
+
+The constructor can be called with a tuple or a `Vararg` of `AbstractAccumulator`s. The
+names will be generated automatically. One can also call the constructor with a `NamedTuple`,
+but the names in the argument will be discarded in favour of the generated ones.
+"""
+struct AccumulatorTuple{N,T<:NamedTuple}
+    nt::T
+
+    function AccumulatorTuple(t::T) where {N,T<:NTuple{N,AbstractAccumulator}}
+        names = map(accumulator_name, t)
+        nt = NamedTuple{names}(t)
+        return new{N,typeof(nt)}(nt)
+    end
+end
+
+AccumulatorTuple(accs::Vararg{AbstractAccumulator}) = AccumulatorTuple(accs)
+AccumulatorTuple(nt::NamedTuple) = AccumulatorTuple(tuple(nt...))
+
+# When showing with text/plain, leave out information about the wrapper AccumulatorTuple.
+Base.show(io::IO, mime::MIME"text/plain", at::AccumulatorTuple) = show(io, mime, at.nt)
+Base.getindex(at::AccumulatorTuple, idx) = at.nt[idx]
+Base.length(::AccumulatorTuple{N}) where {N} = N
+Base.iterate(at::AccumulatorTuple, args...) = iterate(at.nt, args...)
+function Base.haskey(at::AccumulatorTuple, ::Val{accname}) where {accname}
+    # @inline to ensure constant propagation can resolve this to a compile-time constant.
+    @inline return haskey(at.nt, accname)
+end
+Base.keys(at::AccumulatorTuple) = keys(at.nt)
+
+function Base.convert(::Type{AccumulatorTuple{N,T}}, accs::AccumulatorTuple{N}) where {N,T}
+    return AccumulatorTuple(convert(T, accs.nt))
+end
+
+"""
+    setacc!!(at::AccumulatorTuple, acc::AbstractAccumulator)
+
+Add `acc` to `at`. Returns a new `AccumulatorTuple`.
+
+If an `AbstractAccumulator` with the same `accumulator_name` already exists in `at` it is
+replaced. `at` will never be mutated, but the name has the `!!` for consistency with the
+corresponding function for `AbstractVarInfo`.
+"""
+function setacc!!(at::AccumulatorTuple, acc::AbstractAccumulator)
+    accname = accumulator_name(acc)
+    new_nt = merge(at.nt, NamedTuple{(accname,)}((acc,)))
+    return AccumulatorTuple(new_nt)
+end
+
+"""
+    getacc(at::AccumulatorTuple, ::Val{accname})
+
+Get the accumulator with name `accname` from `at`.
+"""
+function getacc(at::AccumulatorTuple, ::Val{accname}) where {accname}
+    return at[accname]
+end
+
+function Base.map(func::Function, at::AccumulatorTuple)
+    return AccumulatorTuple(map(func, at.nt))
+end
+
+"""
+    map_accumulator(func::Function, at::AccumulatorTuple, ::Val{accname})
+
+Update the accumulator with name `accname` in `at` by calling `func` on it.
+
+Returns a new `AccumulatorTuple`.
+"""
+function map_accumulator(
+    func::Function, at::AccumulatorTuple, ::Val{accname}
+) where {accname}
+    # Would like to write this as
+    # return Accessors.@set at.nt[accname] = func(at[accname])
+    # for readability, but that one isn't type stable due to
+    # https://github.com/JuliaObjects/Accessors.jl/issues/198
+    new_val = func(at[accname])
+    new_nt = merge(at.nt, NamedTuple{(accname,)}((new_val,)))
+    return AccumulatorTuple(new_nt)
+end
diff --git a/src/compiler.jl b/src/compiler.jl
index 6f7489b8e..9eb4835d3 100644
--- a/src/compiler.jl
+++ b/src/compiler.jl
@@ -418,7 +418,7 @@ function generate_tilde_literal(left, right)
     @gensym value
     return quote
         $value, __varinfo__ = $(DynamicPPL.tilde_observe!!)(
-            __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__
+            __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, nothing, __varinfo__
         )
         $value
     end
diff --git a/src/context_implementations.jl b/src/context_implementations.jl
index eb025dec8..b92e49fba 100644
--- a/src/context_implementations.jl
+++ b/src/context_implementations.jl
@@ -14,27 +14,6 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg)))
 require_gradient(spl::Sampler) = false
 require_particles(spl::Sampler) = false
 
-# Allows samplers, etc. to hook into the final logp accumulation in the tilde-pipeline.
-function acclogp_assume!!(context::AbstractContext, vi::AbstractVarInfo, logp)
-    return acclogp_assume!!(NodeTrait(acclogp_assume!!, context), context, vi, logp)
-end
-function acclogp_assume!!(::IsParent, context::AbstractContext, vi::AbstractVarInfo, logp)
-    return acclogp_assume!!(childcontext(context), vi, logp)
-end
-function acclogp_assume!!(::IsLeaf, context::AbstractContext, vi::AbstractVarInfo, logp)
-    return acclogp!!(context, vi, logp)
-end
-
-function acclogp_observe!!(context::AbstractContext, vi::AbstractVarInfo, logp)
-    return acclogp_observe!!(NodeTrait(acclogp_observe!!, context), context, vi, logp)
-end
-function acclogp_observe!!(::IsParent, context::AbstractContext, vi::AbstractVarInfo, logp)
-    return acclogp_observe!!(childcontext(context), vi, logp)
-end
-function acclogp_observe!!(::IsLeaf, context::AbstractContext, vi::AbstractVarInfo, logp)
-    return acclogp!!(context, vi, logp)
-end
-
 # assume
 """
     tilde_assume(context::SamplingContext, right, vn, vi)
@@ -52,36 +31,18 @@ function tilde_assume(context::SamplingContext, right, vn, vi)
     return tilde_assume(context.rng, context.context, context.sampler, right, vn, vi)
 end
 
-# Leaf contexts
 function tilde_assume(context::AbstractContext, args...)
-    return tilde_assume(NodeTrait(tilde_assume, context), context, args...)
+    return tilde_assume(childcontext(context), args...)
 end
-function tilde_assume(::IsLeaf, context::AbstractContext, right, vn, vi)
+function tilde_assume(::DefaultContext, right, vn, vi)
     return assume(right, vn, vi)
 end
-function tilde_assume(::IsParent, context::AbstractContext, args...)
-    return tilde_assume(childcontext(context), args...)
-end
 
 function tilde_assume(rng::Random.AbstractRNG, context::AbstractContext, args...)
-    return tilde_assume(NodeTrait(tilde_assume, context), rng, context, args...)
-end
-function tilde_assume(
-    ::IsLeaf, rng::Random.AbstractRNG, context::AbstractContext, sampler, right, vn, vi
-)
-    return assume(rng, sampler, right, vn, vi)
-end
-function tilde_assume(
-    ::IsParent, rng::Random.AbstractRNG, context::AbstractContext, args...
-)
     return tilde_assume(rng, childcontext(context), args...)
 end
-
-function tilde_assume(::LikelihoodContext, right, vn, vi)
-    return assume(nodist(right), vn, vi)
-end
-function tilde_assume(rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, vi)
-    return assume(rng, sampler, nodist(right), vn, vi)
+function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, vi)
+    return assume(rng, sampler, right, vn, vi)
 end
 
 function tilde_assume(context::PrefixContext, right, vn, vi)
@@ -137,55 +98,42 @@ function tilde_assume!!(context, right, vn, vi)
         end
         rand_like!!(right, context, vi)
     else
-        value, logp, vi = tilde_assume(context, right, vn, vi)
-        value, acclogp_assume!!(context, vi, logp)
+        value, vi = tilde_assume(context, right, vn, vi)
+        return value, vi
     end
 end
 
 # observe
 """
-    tilde_observe(context::SamplingContext, right, left, vi)
+    tilde_observe!!(context::SamplingContext, right, left, vi)
 
 Handle observed constants with a `context` associated with a sampler.
 
-Falls back to `tilde_observe(context.context, context.sampler, right, left, vi)`.
+Falls back to `tilde_observe!!(context.context, right, left, vi)`.
 """
-function tilde_observe(context::SamplingContext, right, left, vi)
-    return tilde_observe(context.context, context.sampler, right, left, vi)
+function tilde_observe!!(context::SamplingContext, right, left, vn, vi)
+    return tilde_observe!!(context.context, right, left, vn, vi)
 end
 
-# Leaf contexts
-function tilde_observe(context::AbstractContext, args...)
-    return tilde_observe(NodeTrait(tilde_observe, context), context, args...)
-end
-tilde_observe(::IsLeaf, context::AbstractContext, args...) = observe(args...)
-function tilde_observe(::IsParent, context::AbstractContext, args...)
-    return tilde_observe(childcontext(context), args...)
-end
-
-tilde_observe(::PriorContext, right, left, vi) = 0, vi
-tilde_observe(::PriorContext, sampler, right, left, vi) = 0, vi
-
-# `MiniBatchContext`
-function tilde_observe(context::MiniBatchContext, right, left, vi)
-    logp, vi = tilde_observe(context.context, right, left, vi)
-    return context.loglike_scalar * logp, vi
-end
-function tilde_observe(context::MiniBatchContext, sampler, right, left, vi)
-    logp, vi = tilde_observe(context.context, sampler, right, left, vi)
-    return context.loglike_scalar * logp, vi
+function tilde_observe!!(context::AbstractContext, right, left, vn, vi)
+    return tilde_observe!!(childcontext(context), right, left, vn, vi)
 end
 
 # `PrefixContext`
-function tilde_observe(context::PrefixContext, right, left, vi)
-    return tilde_observe(context.context, right, left, vi)
-end
-function tilde_observe(context::PrefixContext, sampler, right, left, vi)
-    return tilde_observe(context.context, sampler, right, left, vi)
+function tilde_observe!!(context::PrefixContext, right, left, vn, vi)
+    # In the observe case, unlike assume, `vn` may be `nothing` if the LHS is a literal
+    # value. For the need for prefix_and_strip_contexts rather than just prefix, see the
+    # comment in `tilde_assume!!`.
+    new_vn, new_context = if vn !== nothing
+        prefix_and_strip_contexts(context, vn)
+    else
+        vn, childcontext(context)
+    end
+    return tilde_observe!!(new_context, right, left, new_vn, vi)
 end
 
 """
-    tilde_observe!!(context, right, left, vname, vi)
+    tilde_observe!!(context, right, left, vn, vi)
 
 Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
 accumulate the log probability, and return the observed value and updated `vi`.
@@ -193,46 +141,27 @@ accumulate the log probability, and return the observed value and updated `vi`.
 Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the information about variable name
 and indices; if needed, these can be accessed through this function, though.
 """
-function tilde_observe!!(context, right, left, vname, vi)
-    is_rhs_model(right) && throw(
-        ArgumentError(
-            "`~` with a model on the right-hand side of an observe statement is not supported",
-        ),
-    )
-    return tilde_observe!!(context, right, left, vi)
-end
-
-"""
-    tilde_observe(context, right, left, vi)
-
-Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and
-return the observed value.
-
-By default, calls `tilde_observe(context, right, left, vi)` and accumulates the log
-probability of `vi` with the returned value.
-"""
-function tilde_observe!!(context, right, left, vi)
+function tilde_observe!!(context::DefaultContext, right, left, vn, vi)
     is_rhs_model(right) && throw(
         ArgumentError(
             "`~` with a model on the right-hand side of an observe statement is not supported",
         ),
     )
-    logp, vi = tilde_observe(context, right, left, vi)
-    return left, acclogp_observe!!(context, vi, logp)
+    vi = accumulate_observe!!(vi, right, left, vn)
+    return left, vi
 end
 
 function assume(rng::Random.AbstractRNG, spl::Sampler, dist)
     return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))")
 end
 
-function observe(spl::Sampler, weight)
-    return error("DynamicPPL.observe: unmanaged inference algorithm: $(typeof(spl))")
-end
-
 # fallback without sampler
 function assume(dist::Distribution, vn::VarName, vi)
-    r, logp = invlink_with_logpdf(vi, vn, dist)
-    return r, logp, vi
+    y = getindex_internal(vi, vn)
+    f = from_maybe_linked_internal_transform(vi, vn, dist)
+    x, logjac = with_logabsdet_jacobian(f, y)
+    vi = accumulate_assume!!(vi, x, logjac, vn, dist)
+    return x, vi
 end
 
 # TODO: Remove this thing.
@@ -254,8 +183,7 @@ function assume(
             r = init(rng, dist, sampler)
             f = to_maybe_linked_internal_transform(vi, vn, dist)
             # TODO(mhauru) This should probably be call a function called setindex_internal!
-            # Also, if we use !! we shouldn't ignore the return value.
-            BangBang.setindex!!(vi, f(r), vn)
+            vi = BangBang.setindex!!(vi, f(r), vn)
             setorder!(vi, vn, get_num_produce(vi))
         else
             # Otherwise we just extract it.
@@ -265,22 +193,16 @@ function assume(
         r = init(rng, dist, sampler)
         if istrans(vi)
             f = to_linked_internal_transform(vi, vn, dist)
-            push!!(vi, vn, f(r), dist)
+            vi = push!!(vi, vn, f(r), dist)
             # By default `push!!` sets the transformed flag to `false`.
-            settrans!!(vi, true, vn)
+            vi = settrans!!(vi, true, vn)
         else
-            push!!(vi, vn, r, dist)
+            vi = push!!(vi, vn, r, dist)
         end
     end
 
     # HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
     logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r)
-    return r, logpdf(dist, r) - logjac, vi
-end
-
-# default fallback (used e.g. by `SampleFromPrior` and `SampleUniform`)
-observe(sampler::AbstractSampler, right, left, vi) = observe(right, left, vi)
-function observe(right::Distribution, left, vi)
-    increment_num_produce!(vi)
-    return Distributions.loglikelihood(right, left), vi
+    vi = accumulate_assume!!(vi, r, -logjac, vn, dist)
+    return r, vi
 end
diff --git a/src/contexts.jl b/src/contexts.jl
index 8ac085663..addadfa1a 100644
--- a/src/contexts.jl
+++ b/src/contexts.jl
@@ -45,15 +45,17 @@ effectively updating the child context.
 
 # Examples
 ```jldoctest
+julia> using DynamicPPL: DynamicTransformationContext
+
 julia> ctx = SamplingContext();
 
 julia> DynamicPPL.childcontext(ctx)
 DefaultContext()
 
-julia> ctx_prior = DynamicPPL.setchildcontext(ctx, PriorContext()); # only compute the logprior
+julia> ctx_prior = DynamicPPL.setchildcontext(ctx, DynamicTransformationContext{true}());
 
 julia> DynamicPPL.childcontext(ctx_prior)
-PriorContext()
+DynamicTransformationContext{true}()
 ```
 """
 setchildcontext
@@ -78,7 +80,7 @@ original leaf context of `left`.
 
 # Examples
 ```jldoctest
-julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext
+julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext, DynamicTransformationContext
 
 julia> struct ParentContext{C} <: AbstractContext
            context::C
@@ -96,8 +98,8 @@ julia> ctx = ParentContext(ParentContext(DefaultContext()))
 ParentContext(ParentContext(DefaultContext()))
 
 julia> # Replace the leaf context with another leaf.
-       leafcontext(setleafcontext(ctx, PriorContext()))
-PriorContext()
+       leafcontext(setleafcontext(ctx, DynamicTransformationContext{true}()))
+DynamicTransformationContext{true}()
 
 julia> # Append another parent context.
        setleafcontext(ctx, ParentContext(DefaultContext()))
@@ -129,7 +131,7 @@ setleafcontext(::IsLeaf, ::IsLeaf, left, right) = right
 Create a context that allows you to sample parameters with the `sampler` when running the model.
 The `context` determines how the returned log density is computed when running the model.
 
-See also: [`DefaultContext`](@ref), [`LikelihoodContext`](@ref), [`PriorContext`](@ref)
+See also: [`DefaultContext`](@ref)
 """
 struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractContext
     rng::R
@@ -189,52 +191,11 @@ getsampler(::IsLeaf, ::AbstractContext) = error("No sampler found in context")
 """
     struct DefaultContext <: AbstractContext end
 
-The `DefaultContext` is used by default to compute the log joint probability of the data
-and parameters when running the model.
+The `DefaultContext` is used by default to accumulate values like the log joint probability
+when running the model.
 """
 struct DefaultContext <: AbstractContext end
-NodeTrait(context::DefaultContext) = IsLeaf()
-
-"""
-    PriorContext <: AbstractContext
-
-A leaf context resulting in the exclusion of likelihood terms when running the model.
-"""
-struct PriorContext <: AbstractContext end
-NodeTrait(context::PriorContext) = IsLeaf()
-
-"""
-    LikelihoodContext <: AbstractContext
-
-A leaf context resulting in the exclusion of prior terms when running the model.
-"""
-struct LikelihoodContext <: AbstractContext end
-NodeTrait(context::LikelihoodContext) = IsLeaf()
-
-"""
-    struct MiniBatchContext{Tctx, T} <: AbstractContext
-        context::Tctx
-        loglike_scalar::T
-    end
-
-The `MiniBatchContext` enables the computation of
-`log(prior) + s * log(likelihood of a batch)` when running the model, where `s` is the
-`loglike_scalar` field, typically equal to `the number of data points / batch size`.
-This is useful in batch-based stochastic gradient descent algorithms to be optimizing
-`log(prior) + log(likelihood of all the data points)` in the expectation.
-"""
-struct MiniBatchContext{Tctx,T} <: AbstractContext
-    context::Tctx
-    loglike_scalar::T
-end
-function MiniBatchContext(context=DefaultContext(); batch_size, npoints)
-    return MiniBatchContext(context, npoints / batch_size)
-end
-NodeTrait(context::MiniBatchContext) = IsParent()
-childcontext(context::MiniBatchContext) = context.context
-function setchildcontext(parent::MiniBatchContext, child)
-    return MiniBatchContext(child, parent.loglike_scalar)
-end
+NodeTrait(::DefaultContext) = IsLeaf()
 
 """
     PrefixContext(vn::VarName[, context::AbstractContext])
diff --git a/src/debug_utils.jl b/src/debug_utils.jl
index 15ef8fb01..238cd422d 100644
--- a/src/debug_utils.jl
+++ b/src/debug_utils.jl
@@ -76,7 +76,6 @@ Base.@kwdef struct AssumeStmt <: Stmt
     varname
     right
     value
-    logp
     varinfo = nothing
 end
 
@@ -89,16 +88,12 @@ function Base.show(io::IO, stmt::AssumeStmt)
     print(io, " ")
     print(io, RESULT_SYMBOL)
     print(io, " ")
-    print(io, stmt.value)
-    print(io, " (logprob = ")
-    print(io, stmt.logp)
-    return print(io, ")")
+    return print(io, stmt.value)
 end
 
 Base.@kwdef struct ObserveStmt <: Stmt
     left
     right
-    logp
     varinfo = nothing
 end
 
@@ -107,10 +102,7 @@ function Base.show(io::IO, stmt::ObserveStmt)
     print(io, "observe: ")
     show_right(io, stmt.left)
     print(io, " ~ ")
-    show_right(io, stmt.right)
-    print(io, " (logprob = ")
-    print(io, stmt.logp)
-    return print(io, ")")
+    return show_right(io, stmt.right)
 end
 
 # Some utility methods for extracting information from a trace.
@@ -252,12 +244,11 @@ function record_pre_tilde_assume!(context::DebugContext, vn, dist, varinfo)
     return nothing
 end
 
-function record_post_tilde_assume!(context::DebugContext, vn, dist, value, logp, varinfo)
+function record_post_tilde_assume!(context::DebugContext, vn, dist, value, varinfo)
     stmt = AssumeStmt(;
         varname=vn,
         right=dist,
         value=value,
-        logp=logp,
         varinfo=context.record_varinfo ? varinfo : nothing,
     )
     if context.record_statements
@@ -268,19 +259,17 @@ end
 
 function DynamicPPL.tilde_assume(context::DebugContext, right, vn, vi)
     record_pre_tilde_assume!(context, vn, right, vi)
-    value, logp, vi = DynamicPPL.tilde_assume(childcontext(context), right, vn, vi)
-    record_post_tilde_assume!(context, vn, right, value, logp, vi)
-    return value, logp, vi
+    value, vi = DynamicPPL.tilde_assume(childcontext(context), right, vn, vi)
+    record_post_tilde_assume!(context, vn, right, value, vi)
+    return value, vi
 end
 function DynamicPPL.tilde_assume(
     rng::Random.AbstractRNG, context::DebugContext, sampler, right, vn, vi
 )
     record_pre_tilde_assume!(context, vn, right, vi)
-    value, logp, vi = DynamicPPL.tilde_assume(
-        rng, childcontext(context), sampler, right, vn, vi
-    )
-    record_post_tilde_assume!(context, vn, right, value, logp, vi)
-    return value, logp, vi
+    value, vi = DynamicPPL.tilde_assume(rng, childcontext(context), sampler, right, vn, vi)
+    record_post_tilde_assume!(context, vn, right, value, vi)
+    return value, vi
 end
 
 # observe
@@ -304,12 +293,9 @@ function record_pre_tilde_observe!(context::DebugContext, left, dist, varinfo)
     end
 end
 
-function record_post_tilde_observe!(context::DebugContext, left, right, logp, varinfo)
+function record_post_tilde_observe!(context::DebugContext, left, right, varinfo)
     stmt = ObserveStmt(;
-        left=left,
-        right=right,
-        logp=logp,
-        varinfo=context.record_varinfo ? varinfo : nothing,
+        left=left, right=right, varinfo=context.record_varinfo ? varinfo : nothing
     )
     if context.record_statements
         push!(context.statements, stmt)
@@ -317,17 +303,17 @@ function record_post_tilde_observe!(context::DebugContext, left, right, logp, va
     return nothing
 end
 
-function DynamicPPL.tilde_observe(context::DebugContext, right, left, vi)
+function DynamicPPL.tilde_observe!!(context::DebugContext, right, left, vn, vi)
     record_pre_tilde_observe!(context, left, right, vi)
-    logp, vi = DynamicPPL.tilde_observe(childcontext(context), right, left, vi)
-    record_post_tilde_observe!(context, left, right, logp, vi)
-    return logp, vi
+    vi = DynamicPPL.tilde_observe!!(childcontext(context), right, left, vn, vi)
+    record_post_tilde_observe!(context, left, right, vi)
+    return vi
 end
-function DynamicPPL.tilde_observe(context::DebugContext, sampler, right, left, vi)
+function DynamicPPL.tilde_observe!!(context::DebugContext, sampler, right, left, vn, vi)
     record_pre_tilde_observe!(context, left, right, vi)
-    logp, vi = DynamicPPL.tilde_observe(childcontext(context), sampler, right, left, vi)
-    record_post_tilde_observe!(context, left, right, logp, vi)
-    return logp, vi
+    vi = DynamicPPL.tilde_observe!!(childcontext(context), sampler, right, left, vn, vi)
+    record_post_tilde_observe!(context, left, right, vi)
+    return vi
 end
 
 _conditioned_varnames(d::AbstractDict) = keys(d)
@@ -413,7 +399,7 @@ julia> issuccess
 true
 
 julia> print(trace)
- assume: x ~ Normal{Float64}(μ=0.0, σ=1.0) ⟼ -0.670252 (logprob = -1.14356)
+ assume: x ~ Normal{Float64}(μ=0.0, σ=1.0) ⟼ -0.670252
 
 julia> issuccess, trace = check_model_and_trace(rng, demo_correct() | (x = 1.0,));
 
@@ -421,7 +407,7 @@ julia> issuccess
 true
 
 julia> print(trace)
-observe: 1.0 ~ Normal{Float64}(μ=0.0, σ=1.0) (logprob = -1.41894)
+observe: 1.0 ~ Normal{Float64}(μ=0.0, σ=1.0)
 ```
 
 ## Incorrect model
diff --git a/src/default_accumulators.jl b/src/default_accumulators.jl
new file mode 100644
index 000000000..ab538ba51
--- /dev/null
+++ b/src/default_accumulators.jl
@@ -0,0 +1,154 @@
+"""
+    LogPriorAccumulator{T<:Real} <: AbstractAccumulator
+
+An accumulator that tracks the cumulative log prior during model execution.
+
+# Fields
+$(TYPEDFIELDS)
+"""
+struct LogPriorAccumulator{T<:Real} <: AbstractAccumulator
+    "the scalar log prior value"
+    logp::T
+end
+
+"""
+    LogPriorAccumulator{T}()
+
+Create a new `LogPriorAccumulator` accumulator with the log prior initialized to zero.
+"""
+LogPriorAccumulator{T}() where {T<:Real} = LogPriorAccumulator(zero(T))
+LogPriorAccumulator() = LogPriorAccumulator{LogProbType}()
+
+"""
+    LogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator
+
+An accumulator that tracks the cumulative log likelihood during model execution.
+
+# Fields
+$(TYPEDFIELDS)
+"""
+struct LogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator
+    "the scalar log likelihood value"
+    logp::T
+end
+
+"""
+    LogLikelihoodAccumulator{T}()
+
+Create a new `LogLikelihoodAccumulator` accumulator with the log likelihood initialized to zero.
+"""
+LogLikelihoodAccumulator{T}() where {T<:Real} = LogLikelihoodAccumulator(zero(T))
+LogLikelihoodAccumulator() = LogLikelihoodAccumulator{LogProbType}()
+
+"""
+    NumProduceAccumulator{T} <: AbstractAccumulator
+
+An accumulator that tracks the number of observations during model execution.
+
+# Fields
+$(TYPEDFIELDS)
+"""
+struct NumProduceAccumulator{T<:Integer} <: AbstractAccumulator
+    "the number of observations"
+    num::T
+end
+
+"""
+    NumProduceAccumulator{T}()
+
+Create a new `NumProduceAccumulator` accumulator with the number of observations initialized to zero.
+"""
+NumProduceAccumulator{T}() where {T<:Integer} = NumProduceAccumulator(zero(T))
+NumProduceAccumulator() = NumProduceAccumulator{Int}()
+
+function Base.show(io::IO, acc::LogPriorAccumulator)
+    return print(io, "LogPriorAccumulator($(repr(acc.logp)))")
+end
+function Base.show(io::IO, acc::LogLikelihoodAccumulator)
+    return print(io, "LogLikelihoodAccumulator($(repr(acc.logp)))")
+end
+function Base.show(io::IO, acc::NumProduceAccumulator)
+    return print(io, "NumProduceAccumulator($(repr(acc.num)))")
+end
+
+accumulator_name(::Type{<:LogPriorAccumulator}) = :LogPrior
+accumulator_name(::Type{<:LogLikelihoodAccumulator}) = :LogLikelihood
+accumulator_name(::Type{<:NumProduceAccumulator}) = :NumProduce
+
+split(::LogPriorAccumulator{T}) where {T} = LogPriorAccumulator(zero(T))
+split(::LogLikelihoodAccumulator{T}) where {T} = LogLikelihoodAccumulator(zero(T))
+split(acc::NumProduceAccumulator) = acc
+
+function combine(acc::LogPriorAccumulator, acc2::LogPriorAccumulator)
+    return LogPriorAccumulator(acc.logp + acc2.logp)
+end
+function combine(acc::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator)
+    return LogLikelihoodAccumulator(acc.logp + acc2.logp)
+end
+function combine(acc::NumProduceAccumulator, acc2::NumProduceAccumulator)
+    return NumProduceAccumulator(max(acc.num, acc2.num))
+end
+
+function Base.:+(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator)
+    return LogPriorAccumulator(acc1.logp + acc2.logp)
+end
+function Base.:+(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator)
+    return LogLikelihoodAccumulator(acc1.logp + acc2.logp)
+end
+increment(acc::NumProduceAccumulator) = NumProduceAccumulator(acc.num + oneunit(acc.num))
+
+Base.zero(acc::LogPriorAccumulator) = LogPriorAccumulator(zero(acc.logp))
+Base.zero(acc::LogLikelihoodAccumulator) = LogLikelihoodAccumulator(zero(acc.logp))
+Base.zero(acc::NumProduceAccumulator) = NumProduceAccumulator(zero(acc.num))
+
+function accumulate_assume!!(acc::LogPriorAccumulator, val, logjac, vn, right)
+    return acc + LogPriorAccumulator(logpdf(right, val) + logjac)
+end
+accumulate_observe!!(acc::LogPriorAccumulator, right, left, vn) = acc
+
+accumulate_assume!!(acc::LogLikelihoodAccumulator, val, logjac, vn, right) = acc
+function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn)
+    # Note that it's important to use the loglikelihood function here, not logpdf, because
+    # they handle vectors differently:
+    # https://github.com/JuliaStats/Distributions.jl/issues/1972
+    return acc + LogLikelihoodAccumulator(Distributions.loglikelihood(right, left))
+end
+
+accumulate_assume!!(acc::NumProduceAccumulator, val, logjac, vn, right) = acc
+accumulate_observe!!(acc::NumProduceAccumulator, right, left, vn) = increment(acc)
+
+function Base.convert(::Type{LogPriorAccumulator{T}}, acc::LogPriorAccumulator) where {T}
+    return LogPriorAccumulator(convert(T, acc.logp))
+end
+function Base.convert(
+    ::Type{LogLikelihoodAccumulator{T}}, acc::LogLikelihoodAccumulator
+) where {T}
+    return LogLikelihoodAccumulator(convert(T, acc.logp))
+end
+function Base.convert(
+    ::Type{NumProduceAccumulator{T}}, acc::NumProduceAccumulator
+) where {T}
+    return NumProduceAccumulator(convert(T, acc.num))
+end
+
+# TODO(mhauru)
+# We ignore the convert_eltype calls for NumProduceAccumulator by letting them fall back on
+# convert_eltype(::Type, acc::AbstractAccumulator). This is because they are only used to
+# deal with the dual number types of AD backends, which shouldn't concern
+# NumProduceAccumulator. This is horribly hacky and should be fixed. See also the comment in
+# `unflatten` in `src/varinfo.jl`.
+function convert_eltype(::Type{T}, acc::LogPriorAccumulator) where {T}
+    return LogPriorAccumulator(convert(T, acc.logp))
+end
+function convert_eltype(::Type{T}, acc::LogLikelihoodAccumulator) where {T}
+    return LogLikelihoodAccumulator(convert(T, acc.logp))
+end
+
+function default_accumulators(
+    ::Type{FloatT}=LogProbType, ::Type{IntT}=Int
+) where {FloatT,IntT}
+    return AccumulatorTuple(
+        LogPriorAccumulator{FloatT}(),
+        LogLikelihoodAccumulator{FloatT}(),
+        NumProduceAccumulator{IntT}(),
+    )
+end
diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl
index a42855f05..1b5e9b8c4 100644
--- a/src/logdensityfunction.jl
+++ b/src/logdensityfunction.jl
@@ -51,7 +51,7 @@ $(FIELDS)
 ```jldoctest
 julia> using Distributions
 
-julia> using DynamicPPL: LogDensityFunction, contextualize
+julia> using DynamicPPL: LogDensityFunction, setaccs!!
 
 julia> @model function demo(x)
            m ~ Normal()
@@ -78,8 +78,8 @@ julia> # By default it uses `VarInfo` under the hood, but this is not necessary.
 julia> LogDensityProblems.logdensity(f, [0.0])
 -2.3378770664093453
 
-julia> # This also respects the context in `model`.
-       f_prior = LogDensityFunction(contextualize(model, DynamicPPL.PriorContext()), VarInfo(model));
+julia> # LogDensityFunction respects the accumulators in VarInfo:
+       f_prior = LogDensityFunction(model, setaccs!!(VarInfo(model), (LogPriorAccumulator(),)));
 
 julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0)
 true
@@ -174,14 +174,26 @@ end
 
 Evaluate the log density of the given `model` at the given parameter values `x`,
 using the given `varinfo` and `context`. Note that the `varinfo` argument is provided
-only for its structure, in the sense that the parameters from the vector `x` are inserted into
-it, and its own parameters are discarded. 
+only for its structure, in the sense that the parameters from the vector `x` are inserted
+into it, and its own parameters are discarded. It does, however, determine whether the log
+prior, likelihood, or joint is returned, based on which accumulators are set in it.
 """
 function logdensity_at(
     x::AbstractVector, model::Model, varinfo::AbstractVarInfo, context::AbstractContext
 )
     varinfo_new = unflatten(varinfo, x)
-    return getlogp(last(evaluate!!(model, varinfo_new, context)))
+    varinfo_eval = last(evaluate!!(model, varinfo_new, context))
+    has_prior = hasacc(varinfo_eval, Val(:LogPrior))
+    has_likelihood = hasacc(varinfo_eval, Val(:LogLikelihood))
+    if has_prior && has_likelihood
+        return getlogjoint(varinfo_eval)
+    elseif has_prior
+        return getlogprior(varinfo_eval)
+    elseif has_likelihood
+        return getloglikelihood(varinfo_eval)
+    else
+        error("LogDensityFunction: varinfo tracks neither log prior nor log likelihood")
+    end
 end
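+
+# Hypothetical usage sketch: which quantity `logdensity_at` returns is controlled purely by
+# the accumulators carried by `varinfo`. Assuming `model` is a DynamicPPL model and `x` a
+# parameter vector of the right length:
+#
+#     vi_joint = VarInfo(model)                                    # default accumulators
+#     vi_prior = setaccs!!(VarInfo(model), (LogPriorAccumulator(),))
+#     logdensity_at(x, model, vi_joint, DefaultContext())          # log prior + log likelihood
+#     logdensity_at(x, model, vi_prior, DefaultContext())          # log prior only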
 
 ### LogDensityProblems interface
diff --git a/src/model.jl b/src/model.jl
index c7c4bdf57..3b93fa14d 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -900,7 +900,7 @@ See also: [`evaluate_threadunsafe!!`](@ref)
 function evaluate_threadsafe!!(model, varinfo, context)
     wrapper = ThreadSafeVarInfo(resetlogp!!(varinfo))
     result, wrapper_new = _evaluate!!(model, wrapper, context)
-    return result, setlogp!!(wrapper_new.varinfo, getlogp(wrapper_new))
+    return result, setaccs!!(wrapper_new.varinfo, getaccs(wrapper_new))
 end
 
 """
@@ -1010,7 +1010,7 @@ Return the log joint probability of variables `varinfo` for the probabilistic `m
 See [`logprior`](@ref) and [`loglikelihood`](@ref).
 """
 function logjoint(model::Model, varinfo::AbstractVarInfo)
-    return getlogp(last(evaluate!!(model, varinfo, DefaultContext())))
+    return getlogjoint(last(evaluate!!(model, varinfo, DefaultContext())))
 end
 
 """
@@ -1057,7 +1057,14 @@ Return the log prior probability of variables `varinfo` for the probabilistic `m
 See also [`logjoint`](@ref) and [`loglikelihood`](@ref).
 """
 function logprior(model::Model, varinfo::AbstractVarInfo)
-    return getlogp(last(evaluate!!(model, varinfo, PriorContext())))
+    # Remove other accumulators from varinfo, since they are unnecessary.
+    logprioracc = if hasacc(varinfo, Val(:LogPrior))
+        getacc(varinfo, Val(:LogPrior))
+    else
+        LogPriorAccumulator()
+    end
+    varinfo = setaccs!!(deepcopy(varinfo), (logprioracc,))
+    return getlogprior(last(evaluate!!(model, varinfo, DefaultContext())))
 end
 
 """
@@ -1104,7 +1111,14 @@ Return the log likelihood of variables `varinfo` for the probabilistic `model`.
 See also [`logjoint`](@ref) and [`logprior`](@ref).
 """
 function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo)
-    return getlogp(last(evaluate!!(model, varinfo, LikelihoodContext())))
+    # Remove other accumulators from varinfo, since they are unnecessary.
+    loglikelihoodacc = if hasacc(varinfo, Val(:LogLikelihood))
+        getacc(varinfo, Val(:LogLikelihood))
+    else
+        LogLikelihoodAccumulator()
+    end
+    varinfo = setaccs!!(deepcopy(varinfo), (loglikelihoodacc,))
+    return getloglikelihood(last(evaluate!!(model, varinfo, DefaultContext())))
 end
 
 """
@@ -1358,7 +1372,7 @@ We can check that the log joint probability of the model accumulated in `vi` is
 ```jldoctest submodel-to_submodel
 julia> x = vi[@varname(a.x)];
 
-julia> getlogp(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4)
+julia> getlogjoint(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4)
 true
 ```
 
@@ -1422,7 +1436,7 @@ julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x);
 
 julia> loglikelihood = logpdf(Uniform(-1 - abs(sub1_x), 1 + abs(sub2_x)), 0.4);
 
-julia> getlogp(vi) ≈ logprior + loglikelihood
+julia> getlogjoint(vi) ≈ logprior + loglikelihood
 true
 ```
 
diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl
index cb9ea4894..b6b97c8f9 100644
--- a/src/pointwise_logdensities.jl
+++ b/src/pointwise_logdensities.jl
@@ -1,142 +1,117 @@
-# Context version
-struct PointwiseLogdensityContext{A,Ctx} <: AbstractContext
-    logdensities::A
-    context::Ctx
-end
+"""
+    PointwiseLogProbAccumulator{whichlogprob,KeyType,D<:AbstractDict{KeyType}} <: AbstractAccumulator
 
-function PointwiseLogdensityContext(
-    likelihoods=OrderedDict{VarName,Vector{Float64}}(),
-    context::AbstractContext=DefaultContext(),
-)
-    return PointwiseLogdensityContext{typeof(likelihoods),typeof(context)}(
-        likelihoods, context
-    )
-end
+An accumulator that stores the log-probabilities of each variable in a model.
 
-NodeTrait(::PointwiseLogdensityContext) = IsParent()
-childcontext(context::PointwiseLogdensityContext) = context.context
-function setchildcontext(context::PointwiseLogdensityContext, child)
-    return PointwiseLogdensityContext(context.logdensities, child)
-end
+Internally this accumulator stores the log-probabilities in a dictionary, where the keys
+are the variable names and the values are vectors of log-probabilities. Each element in a
+vector corresponds to one execution of the model.
 
-function Base.push!(
-    context::PointwiseLogdensityContext{<:AbstractDict{VarName,Vector{Float64}}},
-    vn::VarName,
-    logp::Real,
-)
-    lookup = context.logdensities
-    ℓ = get!(lookup, vn, Float64[])
-    return push!(ℓ, logp)
+`whichlogprob` is a symbol that can be `:both`, `:prior`, or `:likelihood`, and specifies
+which log-probabilities to store in the accumulator. `KeyType` is the type by which variable
+names are stored, and should be `String` or `VarName`. `D` is the type of the dictionary
+used internally to store the log-probabilities, by default
+`OrderedDict{KeyType, Vector{LogProbType}}`.
+"""
+struct PointwiseLogProbAccumulator{whichlogprob,KeyType,D<:AbstractDict{KeyType}} <:
+       AbstractAccumulator
+    logps::D
 end
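+
+# Illustrative sketch of the stored state (values are made up): after evaluating a model
+# containing `x ~ Normal()` for two different samples, the accumulator's `logps` field
+# might look like
+#
+#     OrderedDict(@varname(x) => [-0.92, -1.41])
+#
+# i.e. one entry per model execution, keyed by `VarName` (or `String`, depending on
+# `KeyType`).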
 
-function Base.push!(
-    context::PointwiseLogdensityContext{<:AbstractDict{VarName,Float64}},
-    vn::VarName,
-    logp::Real,
-)
-    return context.logdensities[vn] = logp
+function PointwiseLogProbAccumulator{whichlogprob}(logps) where {whichlogprob}
+    return PointwiseLogProbAccumulator{whichlogprob,keytype(logps),typeof(logps)}(logps)
 end
 
-function Base.push!(
-    context::PointwiseLogdensityContext{<:AbstractDict{String,Vector{Float64}}},
-    vn::VarName,
-    logp::Real,
-)
-    lookup = context.logdensities
-    ℓ = get!(lookup, string(vn), Float64[])
-    return push!(ℓ, logp)
+function PointwiseLogProbAccumulator{whichlogprob}() where {whichlogprob}
+    return PointwiseLogProbAccumulator{whichlogprob,VarName}()
 end
 
-function Base.push!(
-    context::PointwiseLogdensityContext{<:AbstractDict{String,Float64}},
-    vn::VarName,
-    logp::Real,
-)
-    return context.logdensities[string(vn)] = logp
+function PointwiseLogProbAccumulator{whichlogprob,KeyType}() where {whichlogprob,KeyType}
+    logps = OrderedDict{KeyType,Vector{LogProbType}}()
+    return PointwiseLogProbAccumulator{whichlogprob,KeyType,typeof(logps)}(logps)
 end
 
-function Base.push!(
-    context::PointwiseLogdensityContext{<:AbstractDict{String,Vector{Float64}}},
-    vn::String,
-    logp::Real,
-)
-    lookup = context.logdensities
-    ℓ = get!(lookup, vn, Float64[])
-    return push!(ℓ, logp)
+function Base.push!(acc::PointwiseLogProbAccumulator, vn, logp)
+    logps = acc.logps
+    # The last(fieldtypes(eltype(...))) gets the type of the values, rather than the keys.
+    T = last(fieldtypes(eltype(logps)))
+    logpvec = get!(logps, vn, T())
+    return push!(logpvec, logp)
 end
 
 function Base.push!(
-    context::PointwiseLogdensityContext{<:AbstractDict{String,Float64}},
-    vn::String,
-    logp::Real,
-)
-    return context.logdensities[vn] = logp
+    acc::PointwiseLogProbAccumulator{whichlogprob,String}, vn::VarName, logp
+) where {whichlogprob}
+    return push!(acc, string(vn), logp)
 end
 
-function _include_prior(context::PointwiseLogdensityContext)
-    return leafcontext(context) isa Union{PriorContext,DefaultContext}
-end
-function _include_likelihood(context::PointwiseLogdensityContext)
-    return leafcontext(context) isa Union{LikelihoodContext,DefaultContext}
+function accumulator_name(
+    ::Type{<:PointwiseLogProbAccumulator{whichlogprob}}
+) where {whichlogprob}
+    return Symbol("PointwiseLogProbAccumulator{$whichlogprob}")
 end
 
-function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vi)
-    # Defer literal `observe` to child-context.
-    return tilde_observe!!(context.context, right, left, vi)
+function split(acc::PointwiseLogProbAccumulator{whichlogprob}) where {whichlogprob}
+    return PointwiseLogProbAccumulator{whichlogprob}(empty(acc.logps))
 end
-function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, vi)
-    # Completely defer to child context if we are not tracking likelihoods.
-    if !(_include_likelihood(context))
-        return tilde_observe!!(context.context, right, left, vn, vi)
-    end
 
-    # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e.
-    # we have to intercept the call to `tilde_observe!`.
-    logp, vi = tilde_observe(context.context, right, left, vi)
-
-    # Track loglikelihood value.
-    push!(context, vn, logp)
-
-    return left, acclogp!!(vi, logp)
+function combine(
+    acc::PointwiseLogProbAccumulator{whichlogprob},
+    acc2::PointwiseLogProbAccumulator{whichlogprob},
+) where {whichlogprob}
+    return PointwiseLogProbAccumulator{whichlogprob}(mergewith(vcat, acc.logps, acc2.logps))
 end
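+
+# Illustrative sketch of the split/combine pair above (values are made up): `split` hands
+# each thread an empty copy to fill, and `combine` merges the per-variable vectors back
+# together:
+#
+#     a = PointwiseLogProbAccumulator{:both}(OrderedDict(@varname(x) => [-0.9]))
+#     b = PointwiseLogProbAccumulator{:both}(OrderedDict(@varname(x) => [-1.4]))
+#     combine(a, b).logps == OrderedDict(@varname(x) => [-0.9, -1.4])    # true
+#     isempty(split(a).logps)                                            # true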
 
-# Note on submodels (penelopeysm)
-#
-# We don't need to overload tilde_observe!! for Sampleables (yet), because it
-# is currently not possible to evaluate a model with a Sampleable on the RHS
-# of an observe statement.
-#
-# Note that calling tilde_assume!! on a Sampleable does not necessarily imply
-# that there are no observe statements inside the Sampleable. There could well
-# be likelihood terms in there, which must be included in the returned logp.
-# See e.g. the `demo_dot_assume_observe_submodel` demo model.
-#
-# This is handled by passing the same context to rand_like!!, which figures out
-# which terms to include using the context, and also mutates the context and vi
-# appropriately. Thus, we don't need to check against _include_prior(context)
-# here.
-function tilde_assume!!(context::PointwiseLogdensityContext, right::Sampleable, vn, vi)
-    value, vi = DynamicPPL.rand_like!!(right, context, vi)
-    return value, vi
+function accumulate_assume!!(
+    acc::PointwiseLogProbAccumulator{whichlogprob}, val, logjac, vn, right
+) where {whichlogprob}
+    if whichlogprob == :both || whichlogprob == :prior
+        # T is the element type of the vectors that are the values of `acc.logps`. Usually
+        # it's LogProbType.
+        T = eltype(last(fieldtypes(eltype(acc.logps))))
+        subacc = accumulate_assume!!(LogPriorAccumulator{T}(), val, logjac, vn, right)
+        push!(acc, vn, subacc.logp)
+    end
+    return acc
 end
 
-function tilde_assume!!(context::PointwiseLogdensityContext, right, vn, vi)
-    !_include_prior(context) && return (tilde_assume!!(context.context, right, vn, vi))
-    value, logp, vi = tilde_assume(context.context, right, vn, vi)
-    # Track loglikelihood value.
-    push!(context, vn, logp)
-    return value, acclogp!!(vi, logp)
+function accumulate_observe!!(
+    acc::PointwiseLogProbAccumulator{whichlogprob}, right, left, vn
+) where {whichlogprob}
+    # If `vn` is nothing, the LHS of `~` is a literal and we don't have a variable name to
+    # attach this log-probability to, so we do nothing.
+    if vn === nothing
+        return acc
+    end
+    if whichlogprob == :both || whichlogprob == :likelihood
+        # T is the element type of the vectors that are the values of `acc.logps`. Usually
+        # it's LogProbType.
+        T = eltype(last(fieldtypes(eltype(acc.logps))))
+        subacc = accumulate_observe!!(LogLikelihoodAccumulator{T}(), right, left, vn)
+        push!(acc, vn, subacc.logp)
+    end
+    return acc
 end
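+
+# Illustrative example of the `vn === nothing` case above: in a model with a literal
+# observation such as
+#
+#     @model function demo()
+#         m ~ Normal()
+#         1.5 ~ Normal(m)
+#     end
+#
+# the statement `1.5 ~ Normal(m)` reaches `accumulate_observe!!` with `vn === nothing`, so
+# no pointwise entry is recorded for it; only `m` appears in the resulting `logps`.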
 
 """
-    pointwise_logdensities(model::Model, chain::Chains, keytype = String)
+    pointwise_logdensities(
+        model::Model,
+        chain::Chains,
+        keytype=String,
+        context=DefaultContext(),
+        ::Val{whichlogprob}=Val(:both),
+    )
 
 Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}`
 with keys corresponding to symbols of the variables, and values being matrices
 of shape `(num_chains, num_samples)`.
 
 `keytype` specifies what the type of the keys used in the returned `OrderedDict` are.
-Currently, only `String` and `VarName` are supported.
+Currently, only `String` and `VarName` are supported. `context` is the evaluation context,
+and `whichlogprob` specifies which log-probabilities to compute. It can be `:both`,
+`:prior`, or `:likelihood`.
+
+See also: [`pointwise_loglikelihoods`](@ref), [`pointwise_prior_logdensities`](@ref).
 
 # Notes
 Say `y` is a `Vector` of `n` i.i.d. `Normal(μ, σ)` variables, with `μ` and `σ`
@@ -234,14 +209,19 @@ julia> m = demo([1.0; 1.0]);
 julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])]))
 (-1.4189385332046727, -1.4189385332046727)
 ```
-
 """
 function pointwise_logdensities(
-    model::Model, chain, keytype::Type{T}=String, context::AbstractContext=DefaultContext()
-) where {T}
+    model::Model,
+    chain,
+    ::Type{KeyType}=String,
+    context::AbstractContext=DefaultContext(),
+    ::Val{whichlogprob}=Val(:both),
+) where {KeyType,whichlogprob}
     # Get the data by executing the model once
     vi = VarInfo(model)
-    point_context = PointwiseLogdensityContext(OrderedDict{T,Vector{Float64}}(), context)
+
+    AccType = PointwiseLogProbAccumulator{whichlogprob,KeyType}
+    vi = setaccs!!(vi, (AccType(),))
 
     iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
     for (sample_idx, chain_idx) in iters
@@ -249,26 +229,28 @@ function pointwise_logdensities(
         setval!(vi, chain, sample_idx, chain_idx)
 
         # Execute model
-        model(vi, point_context)
+        vi = last(evaluate!!(model, vi, context))
     end
 
+    logps = getacc(vi, Val(accumulator_name(AccType))).logps
     niters = size(chain, 1)
     nchains = size(chain, 3)
     logdensities = OrderedDict(
-        varname => reshape(logliks, niters, nchains) for
-        (varname, logliks) in point_context.logdensities
+        varname => reshape(vals, niters, nchains) for (varname, vals) in logps
     )
     return logdensities
 end
 
 function pointwise_logdensities(
-    model::Model, varinfo::AbstractVarInfo, context::AbstractContext=DefaultContext()
-)
-    point_context = PointwiseLogdensityContext(
-        OrderedDict{VarName,Vector{Float64}}(), context
-    )
-    model(varinfo, point_context)
-    return point_context.logdensities
+    model::Model,
+    varinfo::AbstractVarInfo,
+    context::AbstractContext=DefaultContext(),
+    ::Val{whichlogprob}=Val(:both),
+) where {whichlogprob}
+    AccType = PointwiseLogProbAccumulator{whichlogprob}
+    varinfo = setaccs!!(varinfo, (AccType(),))
+    varinfo = last(evaluate!!(model, varinfo, context))
+    return getacc(varinfo, Val(accumulator_name(AccType))).logps
 end
 
 """
@@ -277,29 +259,19 @@ end
 Compute the pointwise log-likelihoods of the model given the chain.
 This is the same as `pointwise_logdensities(model, chain, context)`, but only
 including the likelihood terms.
-See also: [`pointwise_logdensities`](@ref).
+
+See also: [`pointwise_logdensities`](@ref), [`pointwise_prior_logdensities`](@ref).
 """
 function pointwise_loglikelihoods(
-    model::Model,
-    chain,
-    keytype::Type{T}=String,
-    context::AbstractContext=LikelihoodContext(),
+    model::Model, chain, keytype::Type{T}=String, context::AbstractContext=DefaultContext()
 ) where {T}
-    if !(leafcontext(context) isa LikelihoodContext)
-        throw(ArgumentError("Leaf context should be a LikelihoodContext"))
-    end
-
-    return pointwise_logdensities(model, chain, T, context)
+    return pointwise_logdensities(model, chain, T, context, Val(:likelihood))
 end
 
 function pointwise_loglikelihoods(
-    model::Model, varinfo::AbstractVarInfo, context::AbstractContext=LikelihoodContext()
+    model::Model, varinfo::AbstractVarInfo, context::AbstractContext=DefaultContext()
 )
-    if !(leafcontext(context) isa LikelihoodContext)
-        throw(ArgumentError("Leaf context should be a LikelihoodContext"))
-    end
-
-    return pointwise_logdensities(model, varinfo, context)
+    return pointwise_logdensities(model, varinfo, context, Val(:likelihood))
 end
 
 """
@@ -308,24 +280,17 @@ end
 Compute the pointwise log-prior-densities of the model given the chain.
 This is the same as `pointwise_logdensities(model, chain, context)`, but only
 including the prior terms.
-See also: [`pointwise_logdensities`](@ref).
+
+See also: [`pointwise_logdensities`](@ref), [`pointwise_loglikelihoods`](@ref).
 """
 function pointwise_prior_logdensities(
-    model::Model, chain, keytype::Type{T}=String, context::AbstractContext=PriorContext()
+    model::Model, chain, keytype::Type{T}=String, context::AbstractContext=DefaultContext()
 ) where {T}
-    if !(leafcontext(context) isa PriorContext)
-        throw(ArgumentError("Leaf context should be a PriorContext"))
-    end
-
-    return pointwise_logdensities(model, chain, T, context)
+    return pointwise_logdensities(model, chain, T, context, Val(:prior))
 end
 
 function pointwise_prior_logdensities(
-    model::Model, varinfo::AbstractVarInfo, context::AbstractContext=PriorContext()
+    model::Model, varinfo::AbstractVarInfo, context::AbstractContext=DefaultContext()
 )
-    if !(leafcontext(context) isa PriorContext)
-        throw(ArgumentError("Leaf context should be a PriorContext"))
-    end
-
-    return pointwise_logdensities(model, varinfo, context)
+    return pointwise_logdensities(model, varinfo, context, Val(:prior))
 end
diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl
index abf14b8fc..42fcedfb8 100644
--- a/src/simple_varinfo.jl
+++ b/src/simple_varinfo.jl
@@ -125,18 +125,18 @@ Evaluation in transformed space of course also works:
 
 ```jldoctest simplevarinfo-general
 julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true)
-Transformed SimpleVarInfo((x = -1.0,), 0.0)
+Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), NumProduce = NumProduceAccumulator(0)))
 
 julia> # (✓) Positive probability mass on negative numbers!
-       getlogp(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext())))
+       getlogjoint(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext())))
 -1.3678794411714423
 
 julia> # While if we forget to indicate that it's transformed:
        vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), false)
-SimpleVarInfo((x = -1.0,), 0.0)
+SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), NumProduce = NumProduceAccumulator(0)))
 
 julia> # (✓) No probability mass on negative numbers!
-       getlogp(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext())))
+       getlogjoint(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext())))
 -Inf
 ```
 
@@ -188,41 +188,37 @@ ERROR: type NamedTuple has no field b
 [...]
 ```
 """
-struct SimpleVarInfo{NT,T,C<:AbstractTransformation} <: AbstractVarInfo
+struct SimpleVarInfo{NT,Accs<:AccumulatorTuple,C<:AbstractTransformation} <:
+       AbstractVarInfo
     "underlying representation of the realization represented"
     values::NT
-    "holds the accumulated log-probability"
-    logp::T
+    "tuple of accumulators for things like log prior and log likelihood"
+    accs::Accs
     "represents whether it assumes variables to be transformed"
     transformation::C
 end
 
 transformation(vi::SimpleVarInfo) = vi.transformation
 
-# Makes things a bit more readable vs. putting `Float64` everywhere.
-const SIMPLEVARINFO_DEFAULT_ELTYPE = Float64
-
-function SimpleVarInfo{NT,T}(values, logp) where {NT,T}
-    return SimpleVarInfo{NT,T,NoTransformation}(values, logp, NoTransformation())
+function SimpleVarInfo(values, accs)
+    return SimpleVarInfo(values, accs, NoTransformation())
 end
-function SimpleVarInfo{T}(θ) where {T<:Real}
-    return SimpleVarInfo{typeof(θ),T}(θ, zero(T))
+function SimpleVarInfo{T}(values) where {T<:Real}
+    return SimpleVarInfo(values, default_accumulators(T))
 end
-
-# Constructors without type-specification.
-SimpleVarInfo(θ) = SimpleVarInfo{SIMPLEVARINFO_DEFAULT_ELTYPE}(θ)
-function SimpleVarInfo(θ::Union{<:NamedTuple,<:AbstractDict})
-    return if isempty(θ)
+function SimpleVarInfo(values)
+    return SimpleVarInfo{LogProbType}(values)
+end
+function SimpleVarInfo(values::Union{<:NamedTuple,<:AbstractDict})
+    return if isempty(values)
         # Can't infer from values, so we just use default.
-        SimpleVarInfo{SIMPLEVARINFO_DEFAULT_ELTYPE}(θ)
+        SimpleVarInfo{LogProbType}(values)
     else
         # Infer from `values`.
-        SimpleVarInfo{float_type_with_fallback(infer_nested_eltype(typeof(θ)))}(θ)
+        SimpleVarInfo{float_type_with_fallback(infer_nested_eltype(typeof(values)))}(values)
     end
 end
 
-SimpleVarInfo(values, logp) = SimpleVarInfo{typeof(values),typeof(logp)}(values, logp)
-
 # Using `kwargs` to specify the values.
 function SimpleVarInfo{T}(; kwargs...) where {T<:Real}
     return SimpleVarInfo{T}(NamedTuple(kwargs))
@@ -235,7 +231,7 @@ end
 function SimpleVarInfo(
     model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}...
 )
-    return SimpleVarInfo{Float64}(model, args...)
+    return SimpleVarInfo{LogProbType}(model, args...)
 end
 function SimpleVarInfo{T}(
     model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}...
@@ -244,14 +240,14 @@ function SimpleVarInfo{T}(
 end
 
 # Constructor from `VarInfo`.
-function SimpleVarInfo(vi::NTVarInfo, (::Type{D})=NamedTuple; kwargs...) where {D}
-    return SimpleVarInfo{eltype(getlogp(vi))}(vi, D; kwargs...)
+function SimpleVarInfo(vi::NTVarInfo, ::Type{D}) where {D}
+    values = values_as(vi, D)
+    return SimpleVarInfo(values, deepcopy(getaccs(vi)))
 end
-function SimpleVarInfo{T}(
-    vi::VarInfo{<:NamedTuple{names}}, ::Type{D}
-) where {T<:Real,names,D}
+function SimpleVarInfo{T}(vi::NTVarInfo, ::Type{D}) where {T<:Real,D}
     values = values_as(vi, D)
-    return SimpleVarInfo(values, convert(T, getlogp(vi)))
+    accs = map(acc -> convert_eltype(T, acc), getaccs(vi))
+    return SimpleVarInfo(values, accs)
 end
 
 function untyped_simple_varinfo(model::Model)
@@ -265,12 +261,16 @@ function typed_simple_varinfo(model::Model)
 end
 
 function unflatten(svi::SimpleVarInfo, x::AbstractVector)
-    logp = getlogp(svi)
     vals = unflatten(svi.values, x)
-    T = eltype(x)
-    return SimpleVarInfo{typeof(vals),T,typeof(svi.transformation)}(
-        vals, T(logp), svi.transformation
+    # TODO(mhauru) See comment in unflatten in src/varinfo.jl for why this conversion is
+    # required but undesirable.
+    # The below line is finicky for type stability. For instance, storing the eltype to
+    # convert to in an intermediate variable makes this unstable (constant propagation
+    # fails). Take care when editing.
+    accs = map(
+        acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), getaccs(svi)
     )
+    return SimpleVarInfo(vals, accs, svi.transformation)
 end
 
 function BangBang.empty!!(vi::SimpleVarInfo)
@@ -278,21 +278,8 @@ function BangBang.empty!!(vi::SimpleVarInfo)
 end
 Base.isempty(vi::SimpleVarInfo) = isempty(vi.values)
 
-getlogp(vi::SimpleVarInfo) = vi.logp
-getlogp(vi::SimpleVarInfo{<:Any,<:Ref}) = vi.logp[]
-
-setlogp!!(vi::SimpleVarInfo, logp) = Accessors.@set vi.logp = logp
-acclogp!!(vi::SimpleVarInfo, logp) = Accessors.@set vi.logp = getlogp(vi) + logp
-
-function setlogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp)
-    vi.logp[] = logp
-    return vi
-end
-
-function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp)
-    vi.logp[] += logp
-    return vi
-end
+getaccs(vi::SimpleVarInfo) = vi.accs
+setaccs!!(vi::SimpleVarInfo, accs::AccumulatorTuple) = Accessors.@set vi.accs = accs
 
 """
     keys(vi::SimpleVarInfo)
@@ -302,12 +289,12 @@ Return an iterator of keys present in `vi`.
 Base.keys(vi::SimpleVarInfo) = keys(vi.values)
 Base.keys(vi::SimpleVarInfo{<:NamedTuple}) = map(k -> VarName{k}(), keys(vi.values))
 
-function Base.show(io::IO, ::MIME"text/plain", svi::SimpleVarInfo)
+function Base.show(io::IO, mime::MIME"text/plain", svi::SimpleVarInfo)
     if !(svi.transformation isa NoTransformation)
         print(io, "Transformed ")
     end
 
-    return print(io, "SimpleVarInfo(", svi.values, ", ", svi.logp, ")")
+    return print(io, "SimpleVarInfo(", svi.values, ", ", repr(mime, getaccs(svi)), ")")
 end
 
 function Base.getindex(vi::SimpleVarInfo, vn::VarName, dist::Distribution)
@@ -454,11 +441,11 @@ _subset(x::VarNamedVector, vns) = subset(x, vns)
 # `merge`
 function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo)
     values = merge(varinfo_left.values, varinfo_right.values)
-    logp = getlogp(varinfo_right)
+    accs = deepcopy(getaccs(varinfo_right))
     transformation = merge_transformations(
         varinfo_left.transformation, varinfo_right.transformation
     )
-    return SimpleVarInfo(values, logp, transformation)
+    return SimpleVarInfo(values, accs, transformation)
 end
 
 # Context implementations
@@ -473,9 +460,11 @@ function assume(
 )
     value = init(rng, dist, sampler)
     # Transform if we're working in unconstrained space.
-    value_raw = to_maybe_linked_internal(vi, vn, dist, value)
+    f = to_maybe_linked_internal_transform(vi, vn, dist)
+    value_raw, logjac = with_logabsdet_jacobian(f, value)
     vi = BangBang.push!!(vi, vn, value_raw, dist)
-    return value, Bijectors.logpdf_with_trans(dist, value, istrans(vi, vn)), vi
+    vi = accumulate_assume!!(vi, value, -logjac, vn, dist)
+    return value, vi
 end
 
 # NOTE: We don't implement `settrans!!(vi, trans, vn)`.
@@ -497,8 +486,8 @@ islinked(vi::SimpleVarInfo) = istrans(vi)
 
 values_as(vi::SimpleVarInfo) = vi.values
 values_as(vi::SimpleVarInfo{<:T}, ::Type{T}) where {T} = vi.values
-function values_as(vi::SimpleVarInfo{<:Any,T}, ::Type{Vector}) where {T}
-    isempty(vi) && return T[]
+function values_as(vi::SimpleVarInfo, ::Type{Vector})
+    isempty(vi) && return Any[]
     return mapreduce(tovec, vcat, values(vi.values))
 end
 function values_as(vi::SimpleVarInfo, ::Type{D}) where {D<:AbstractDict}
@@ -613,12 +602,11 @@ function link!!(
     vi::SimpleVarInfo{<:NamedTuple},
     ::Model,
 )
-    # TODO: Make sure that `spl` is respected.
     b = inverse(t.bijector)
     x = vi.values
     y, logjac = with_logabsdet_jacobian(b, x)
-    lp_new = getlogp(vi) - logjac
-    vi_new = setlogp!!(Accessors.@set(vi.values = y), lp_new)
+    vi_new = Accessors.@set(vi.values = y)
+    vi_new = acclogprior!!(vi_new, -logjac)
     return settrans!!(vi_new, t)
 end
 
@@ -627,12 +615,11 @@ function invlink!!(
     vi::SimpleVarInfo{<:NamedTuple},
     ::Model,
 )
-    # TODO: Make sure that `spl` is respected.
     b = t.bijector
     y = vi.values
     x, logjac = with_logabsdet_jacobian(b, y)
-    lp_new = getlogp(vi) + logjac
-    vi_new = setlogp!!(Accessors.@set(vi.values = x), lp_new)
+    vi_new = Accessors.@set(vi.values = x)
+    vi_new = acclogprior!!(vi_new, logjac)
     return settrans!!(vi_new, NoTransformation())
 end
 
@@ -645,13 +632,4 @@ function from_linked_internal_transform(vi::SimpleVarInfo, ::VarName, dist)
     return invlink_transform(dist)
 end
 
-# Threadsafe stuff.
-# For `SimpleVarInfo` we don't really need `Ref` so let's not use it.
-function ThreadSafeVarInfo(vi::SimpleVarInfo)
-    return ThreadSafeVarInfo(vi, zeros(typeof(getlogp(vi)), Threads.nthreads()))
-end
-function ThreadSafeVarInfo(vi::SimpleVarInfo{<:Any,<:Ref})
-    return ThreadSafeVarInfo(vi, [Ref(zero(getlogp(vi))) for _ in 1:Threads.nthreads()])
-end
-
 has_varnamedvector(vi::SimpleVarInfo) = vi.values isa VarNamedVector
diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl
index 5f1ec95ec..bd08b427e 100644
--- a/src/submodel_macro.jl
+++ b/src/submodel_macro.jl
@@ -45,7 +45,7 @@ We can check that the log joint probability of the model accumulated in `vi` is
 ```jldoctest submodel
 julia> x = vi[@varname(x)];
 
-julia> getlogp(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4)
+julia> getlogjoint(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4)
 true
 ```
 """
@@ -124,7 +124,7 @@ julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x);
 
 julia> loglikelihood = logpdf(Uniform(-1 - abs(sub1_x), 1 + abs(sub2_x)), 0.4);
 
-julia> getlogp(vi) ≈ logprior + loglikelihood
+julia> getlogjoint(vi) ≈ logprior + loglikelihood
 true
 ```
 
diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl
index 7404a9af7..08acdfada 100644
--- a/src/test_utils/contexts.jl
+++ b/src/test_utils/contexts.jl
@@ -3,34 +3,6 @@
 #
 # Utilities for testing contexts.
 
-"""
-Context that multiplies each log-prior by mod
-used to test whether varwise_logpriors respects child-context.
-"""
-struct TestLogModifyingChildContext{T,Ctx} <: DynamicPPL.AbstractContext
-    mod::T
-    context::Ctx
-end
-function TestLogModifyingChildContext(
-    mod=1.2, context::DynamicPPL.AbstractContext=DynamicPPL.DefaultContext()
-)
-    return TestLogModifyingChildContext{typeof(mod),typeof(context)}(mod, context)
-end
-
-DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsParent()
-DynamicPPL.childcontext(context::TestLogModifyingChildContext) = context.context
-function DynamicPPL.setchildcontext(context::TestLogModifyingChildContext, child)
-    return TestLogModifyingChildContext(context.mod, child)
-end
-function DynamicPPL.tilde_assume(context::TestLogModifyingChildContext, right, vn, vi)
-    value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi)
-    return value, logp * context.mod, vi
-end
-function DynamicPPL.tilde_observe(context::TestLogModifyingChildContext, right, left, vi)
-    logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi)
-    return logp * context.mod, vi
-end
-
 # Dummy context to test nested behaviors.
 struct TestParentContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext
     context::C
@@ -61,7 +33,7 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod
 
     # To see change, let's make sure we're using a different leaf context than the current.
     leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext
-        PriorContext()
+        DynamicPPL.DynamicTransformationContext{false}()
     else
         DefaultContext()
     end
diff --git a/src/test_utils/models.jl b/src/test_utils/models.jl
index e29614982..12f88acad 100644
--- a/src/test_utils/models.jl
+++ b/src/test_utils/models.jl
@@ -148,7 +148,7 @@ Simple model for which [`default_transformation`](@ref) returns a [`StaticTransf
     1.5 ~ Normal(m, sqrt(s))
     2.0 ~ Normal(m, sqrt(s))
 
-    return (; s, m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
+    return (; s, m, x=[1.5, 2.0])
 end
 
 function DynamicPPL.default_transformation(::Model{typeof(demo_static_transformation)})
@@ -194,7 +194,7 @@ end
     m ~ product_distribution(Normal.(0, sqrt.(s)))
 
     x ~ MvNormal(m, Diagonal(s))
-    return (; s=s, m=m, x=x, logp=getlogp(__varinfo__))
+    return (; s=s, m=m, x=x)
 end
 function logprior_true(model::Model{typeof(demo_dot_assume_observe)}, s, m)
     return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m))
@@ -225,7 +225,7 @@ end
     end
     x ~ MvNormal(m, Diagonal(s))
 
-    return (; s=s, m=m, x=x, logp=getlogp(__varinfo__))
+    return (; s=s, m=m, x=x)
 end
 function logprior_true(model::Model{typeof(demo_assume_index_observe)}, s, m)
     return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m))
@@ -248,7 +248,7 @@ end
     m ~ MvNormal(zero(x), Diagonal(s))
     x ~ MvNormal(m, Diagonal(s))
 
-    return (; s=s, m=m, x=x, logp=getlogp(__varinfo__))
+    return (; s=s, m=m, x=x)
 end
 function logprior_true(model::Model{typeof(demo_assume_multivariate_observe)}, s, m)
     s_dist = product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])
@@ -279,7 +279,7 @@ end
         x[i] ~ Normal(m[i], sqrt(s[i]))
     end
 
-    return (; s=s, m=m, x=x, logp=getlogp(__varinfo__))
+    return (; s=s, m=m, x=x)
 end
 function logprior_true(model::Model{typeof(demo_dot_assume_observe_index)}, s, m)
     return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m))
@@ -304,7 +304,7 @@ end
     m ~ Normal(0, sqrt(s))
     x .~ Normal(m, sqrt(s))
 
-    return (; s=s, m=m, x=x, logp=getlogp(__varinfo__))
+    return (; s=s, m=m, x=x)
 end
 function logprior_true(model::Model{typeof(demo_assume_dot_observe)}, s, m)
     return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m)
@@ -327,7 +327,7 @@ end
     m ~ MvNormal(zeros(2), Diagonal(s))
     [1.5, 2.0] ~ MvNormal(m, Diagonal(s))
 
-    return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
+    return (; s=s, m=m, x=[1.5, 2.0])
 end
 function logprior_true(model::Model{typeof(demo_assume_multivariate_observe_literal)}, s, m)
     s_dist = product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])
@@ -358,7 +358,7 @@ end
     1.5 ~ Normal(m[1], sqrt(s[1]))
     2.0 ~ Normal(m[2], sqrt(s[2]))
 
-    return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
+    return (; s=s, m=m, x=[1.5, 2.0])
 end
 function logprior_true(model::Model{typeof(demo_dot_assume_observe_index_literal)}, s, m)
     return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m))
@@ -384,7 +384,7 @@ end
     1.5 ~ Normal(m, sqrt(s))
     2.0 ~ Normal(m, sqrt(s))
 
-    return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
+    return (; s=s, m=m, x=[1.5, 2.0])
 end
 function logprior_true(model::Model{typeof(demo_assume_observe_literal)}, s, m)
     return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m)
@@ -407,7 +407,7 @@ end
     m ~ Normal(0, sqrt(s))
     [1.5, 2.0] .~ Normal(m, sqrt(s))
 
-    return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
+    return (; s=s, m=m, x=[1.5, 2.0])
 end
 function logprior_true(model::Model{typeof(demo_assume_dot_observe_literal)}, s, m)
     return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m)
@@ -440,7 +440,7 @@ end
     1.5 ~ Normal(m[1], sqrt(s[1]))
     2.0 ~ Normal(m[2], sqrt(s[2]))
 
-    return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
+    return (; s=s, m=m, x=[1.5, 2.0])
 end
 function logprior_true(
     model::Model{typeof(demo_assume_submodel_observe_index_literal)}, s, m
@@ -476,9 +476,9 @@ end
     # Submodel likelihood
     # With to_submodel, we have to have a left-hand side variable to
     # capture the result, so we just use a dummy variable
-    _ignore ~ to_submodel(_likelihood_multivariate_observe(s, m, x))
+    _ignore ~ to_submodel(_likelihood_multivariate_observe(s, m, x), false)
 
-    return (; s=s, m=m, x=x, logp=getlogp(__varinfo__))
+    return (; s=s, m=m, x=x)
 end
 function logprior_true(model::Model{typeof(demo_dot_assume_observe_submodel)}, s, m)
     return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m))
@@ -505,7 +505,7 @@ end
 
     x[:, 1] ~ MvNormal(m, Diagonal(s))
 
-    return (; s=s, m=m, x=x, logp=getlogp(__varinfo__))
+    return (; s=s, m=m, x=x)
 end
 function logprior_true(model::Model{typeof(demo_dot_assume_observe_matrix_index)}, s, m)
     return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m))
@@ -535,7 +535,7 @@ end
 
     x[:, 1] ~ MvNormal(m, Diagonal(s_vec))
 
-    return (; s=s, m=m, x=x, logp=getlogp(__varinfo__))
+    return (; s=s, m=m, x=x)
 end
 function logprior_true(model::Model{typeof(demo_assume_matrix_observe_matrix_index)}, s, m)
     n = length(model.args.x)
diff --git a/src/test_utils/varinfo.jl b/src/test_utils/varinfo.jl
index 539872143..07a308c7a 100644
--- a/src/test_utils/varinfo.jl
+++ b/src/test_utils/varinfo.jl
@@ -37,12 +37,6 @@ function setup_varinfos(
     svi_untyped = SimpleVarInfo(OrderedDict())
     svi_vnv = SimpleVarInfo(DynamicPPL.VarNamedVector())
 
-    # SimpleVarInfo{<:Any,<:Ref}
-    svi_typed_ref = SimpleVarInfo(example_values, Ref(getlogp(svi_typed)))
-    svi_untyped_ref = SimpleVarInfo(OrderedDict(), Ref(getlogp(svi_untyped)))
-    svi_vnv_ref = SimpleVarInfo(DynamicPPL.VarNamedVector(), Ref(getlogp(svi_vnv)))
-
-    lp = getlogp(vi_typed_metadata)
     varinfos = map((
         vi_untyped_metadata,
         vi_untyped_vnv,
@@ -51,12 +45,10 @@ function setup_varinfos(
         svi_typed,
         svi_untyped,
         svi_vnv,
-        svi_typed_ref,
-        svi_untyped_ref,
-        svi_vnv_ref,
     )) do vi
-        # Set them all to the same values.
-        DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp)
+        # Set them all to the same values and evaluate logp.
+        vi = update_values!!(vi, example_values, varnames)
+        last(DynamicPPL.evaluate!!(model, vi, DefaultContext()))
     end
 
     if include_threadsafe
diff --git a/src/threadsafe.jl b/src/threadsafe.jl
index 2dc2645de..7d2d768a6 100644
--- a/src/threadsafe.jl
+++ b/src/threadsafe.jl
@@ -2,69 +2,79 @@
     ThreadSafeVarInfo
 
 A `ThreadSafeVarInfo` object wraps an [`AbstractVarInfo`](@ref) object and an
-array of log probabilities for thread-safe execution of a probabilistic model.
+array of accumulators for thread-safe execution of a probabilistic model.
 """
-struct ThreadSafeVarInfo{V<:AbstractVarInfo,L} <: AbstractVarInfo
+struct ThreadSafeVarInfo{V<:AbstractVarInfo,L<:AccumulatorTuple} <: AbstractVarInfo
     varinfo::V
-    logps::L
+    accs_by_thread::Vector{L}
 end
 function ThreadSafeVarInfo(vi::AbstractVarInfo)
-    return ThreadSafeVarInfo(vi, [Ref(zero(getlogp(vi))) for _ in 1:Threads.nthreads()])
+    accs_by_thread = [map(split, getaccs(vi)) for _ in 1:Threads.nthreads()]
+    return ThreadSafeVarInfo(vi, accs_by_thread)
 end
 ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi
 
-const ThreadSafeVarInfoWithRef{V<:AbstractVarInfo} = ThreadSafeVarInfo{
-    V,<:AbstractArray{<:Ref}
-}
-
 transformation(vi::ThreadSafeVarInfo) = transformation(vi.varinfo)
 
-# Instead of updating the log probability of the underlying variables we
-# just update the array of log probabilities.
-function acclogp!!(vi::ThreadSafeVarInfo, logp)
-    vi.logps[Threads.threadid()] += logp
-    return vi
+# Set the accumulator in question in vi.varinfo, and set the thread-specific
+# accumulators of the same type to be empty.
+function setacc!!(vi::ThreadSafeVarInfo, acc::AbstractAccumulator)
+    inner_vi = setacc!!(vi.varinfo, acc)
+    new_accs_by_thread = map(accs -> setacc!!(accs, split(acc)), vi.accs_by_thread)
+    return ThreadSafeVarInfo(inner_vi, new_accs_by_thread)
 end
-function acclogp!!(vi::ThreadSafeVarInfoWithRef, logp)
-    vi.logps[Threads.threadid()][] += logp
-    return vi
+
+# Get both the main accumulator and the thread-specific accumulators of the same type and
+# combine them.
+function getacc(vi::ThreadSafeVarInfo, accname::Val)
+    main_acc = getacc(vi.varinfo, accname)
+    other_accs = map(accs -> getacc(accs, accname), vi.accs_by_thread)
+    return foldl(combine, other_accs; init=main_acc)
 end
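+
+# Illustrative sketch of the fold above, assuming two threads: if the wrapped varinfo holds
+# an accumulator `main` and the per-thread tuples hold `t1` and `t2` under the same name,
+# then
+#
+#     getacc(vi, Val(:LogLikelihood)) == combine(combine(main, t1), t2)
+#
+# i.e. thread-local contributions are only merged on read, and the wrapped varinfo itself
+# is left untouched.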
 
-# The current log probability of the variables has to be computed from
-# both the wrapped variables and the thread-specific log probabilities.
-getlogp(vi::ThreadSafeVarInfo) = getlogp(vi.varinfo) + sum(vi.logps)
-getlogp(vi::ThreadSafeVarInfoWithRef) = getlogp(vi.varinfo) + sum(getindex, vi.logps)
+hasacc(vi::ThreadSafeVarInfo, accname::Val) = hasacc(vi.varinfo, accname)
+acckeys(vi::ThreadSafeVarInfo) = acckeys(vi.varinfo)
 
-# TODO: Make remaining methods thread-safe.
-function resetlogp!!(vi::ThreadSafeVarInfo)
-    return ThreadSafeVarInfo(resetlogp!!(vi.varinfo), zero(vi.logps))
+function getaccs(vi::ThreadSafeVarInfo)
+    # This method is a bit finicky about type stability. For instance, moving the
+    # accname -> Val(accname) conversion into the main `map` call makes constant
+    # propagation fail and this becomes unstable. Do check the effects if you make edits.
+    accnames = acckeys(vi)
+    accname_vals = map(Val, accnames)
+    return AccumulatorTuple(map(anv -> getacc(vi, anv), accname_vals))
 end
-function resetlogp!!(vi::ThreadSafeVarInfoWithRef)
-    for x in vi.logps
-        x[] = zero(x[])
-    end
-    return ThreadSafeVarInfo(resetlogp!!(vi.varinfo), vi.logps)
-end
-function setlogp!!(vi::ThreadSafeVarInfo, logp)
-    return ThreadSafeVarInfo(setlogp!!(vi.varinfo, logp), zero(vi.logps))
+
+# Calls to map_accumulator(s)!! are thread-specific by default. Any use of them that should
+# _not_ be thread-specific needs its own dedicated method.
+function map_accumulator!!(func::Function, vi::ThreadSafeVarInfo, accname::Val)
+    tid = Threads.threadid()
+    vi.accs_by_thread[tid] = map_accumulator(func, vi.accs_by_thread[tid], accname)
+    return vi
 end
-function setlogp!!(vi::ThreadSafeVarInfoWithRef, logp)
-    for x in vi.logps
-        x[] = zero(x[])
-    end
-    return ThreadSafeVarInfo(setlogp!!(vi.varinfo, logp), vi.logps)
+
+function map_accumulators!!(func::Function, vi::ThreadSafeVarInfo)
+    tid = Threads.threadid()
+    vi.accs_by_thread[tid] = map(func, vi.accs_by_thread[tid])
+    return vi
 end
 
-has_varnamedvector(vi::DynamicPPL.ThreadSafeVarInfo) = has_varnamedvector(vi.varinfo)
+has_varnamedvector(vi::ThreadSafeVarInfo) = has_varnamedvector(vi.varinfo)
 
 function BangBang.push!!(vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution)
     return Accessors.@set vi.varinfo = push!!(vi.varinfo, vn, r, dist)
 end
 
+# TODO(mhauru) Why these short-circuits? Why not use the thread-specific ones?
 get_num_produce(vi::ThreadSafeVarInfo) = get_num_produce(vi.varinfo)
-increment_num_produce!(vi::ThreadSafeVarInfo) = increment_num_produce!(vi.varinfo)
-reset_num_produce!(vi::ThreadSafeVarInfo) = reset_num_produce!(vi.varinfo)
-set_num_produce!(vi::ThreadSafeVarInfo, n::Int) = set_num_produce!(vi.varinfo, n)
+function increment_num_produce!!(vi::ThreadSafeVarInfo)
+    return ThreadSafeVarInfo(increment_num_produce!!(vi.varinfo), vi.accs_by_thread)
+end
+function reset_num_produce!!(vi::ThreadSafeVarInfo)
+    return ThreadSafeVarInfo(reset_num_produce!!(vi.varinfo), vi.accs_by_thread)
+end
+function set_num_produce!!(vi::ThreadSafeVarInfo, n::Int)
+    return ThreadSafeVarInfo(set_num_produce!!(vi.varinfo, n), vi.accs_by_thread)
+end
 
 syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo)
 
@@ -94,8 +104,8 @@ end
 
 # Need to define explicitly for `DynamicTransformation` to avoid method ambiguity.
 # NOTE: We also can't just defer to the wrapped varinfo, because we need to ensure
-# consistency between `vi.logps` field and `getlogp(vi.varinfo)`, which accumulates
-# to define `getlogp(vi)`.
+# consistency between the `vi.accs_by_thread` field and `getacc(vi.varinfo)`, which are
+# combined to define `getacc(vi)`.
 function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
     return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t)
 end
@@ -130,9 +140,9 @@ end
 
 function maybe_invlink_before_eval!!(vi::ThreadSafeVarInfo, model::Model)
     # Defer to the wrapped `AbstractVarInfo` object.
-    # NOTE: When computing `getlogp` for `ThreadSafeVarInfo` we do include the
-    # `getlogp(vi.varinfo)` hence the log-absdet-jacobian term will correctly be included in
-    # the `getlogp(vi)`.
+    # NOTE: When computing `getacc` for `ThreadSafeVarInfo` we do include
+    # `getacc(vi.varinfo)`, hence the log-absdet-jacobian term will correctly be included
+    # in `getlogprior(vi)`.
     return Accessors.@set vi.varinfo = maybe_invlink_before_eval!!(vi.varinfo, model)
 end
 
@@ -169,6 +179,23 @@ function BangBang.empty!!(vi::ThreadSafeVarInfo)
     return resetlogp!!(Accessors.@set(vi.varinfo = empty!!(vi.varinfo)))
 end
 
+function resetlogp!!(vi::ThreadSafeVarInfo)
+    vi = Accessors.@set vi.varinfo = resetlogp!!(vi.varinfo)
+    for i in eachindex(vi.accs_by_thread)
+        if hasacc(vi, Val(:LogPrior))
+            vi.accs_by_thread[i] = map_accumulator(
+                zero, vi.accs_by_thread[i], Val(:LogPrior)
+            )
+        end
+        if hasacc(vi, Val(:LogLikelihood))
+            vi.accs_by_thread[i] = map_accumulator(
+                zero, vi.accs_by_thread[i], Val(:LogLikelihood)
+            )
+        end
+    end
+    return vi
+end
+
 values_as(vi::ThreadSafeVarInfo) = values_as(vi.varinfo)
 values_as(vi::ThreadSafeVarInfo, ::Type{T}) where {T} = values_as(vi.varinfo, T)
 
diff --git a/src/transforming.jl b/src/transforming.jl
index 429562ec8..ddd1ab59f 100644
--- a/src/transforming.jl
+++ b/src/transforming.jl
@@ -27,18 +27,47 @@ function tilde_assume(
     # Only transform if `!isinverse` since `vi[vn, right]`
     # already performs the inverse transformation if it's transformed.
     r_transformed = isinverse ? r : link_transform(right)(r)
-    return r, lp, setindex!!(vi, r_transformed, vn)
+    if hasacc(vi, Val(:LogPrior))
+        vi = acclogprior!!(vi, lp)
+    end
+    return r, setindex!!(vi, r_transformed, vn)
+end
+
+function tilde_observe!!(::DynamicTransformationContext, right, left, vn, vi)
+    return tilde_observe!!(DefaultContext(), right, left, vn, vi)
 end
 
 function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model)
-    return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t)
+    return _transform!!(t, DynamicTransformationContext{false}(), vi, model)
 end
 
 function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model)
-    return settrans!!(
-        last(evaluate!!(model, vi, DynamicTransformationContext{true}())),
-        NoTransformation(),
-    )
+    return _transform!!(NoTransformation(), DynamicTransformationContext{true}(), vi, model)
+end
+
+function _transform!!(
+    t::AbstractTransformation,
+    ctx::DynamicTransformationContext,
+    vi::AbstractVarInfo,
+    model::Model,
+)
+    # To transform using DynamicTransformationContext, we evaluate the model, but we do
+    # not need to use any accumulators other than LogPriorAccumulator (which is affected
+    # by the Jacobian of the transformation).
+    accs = getaccs(vi)
+    has_logprior = haskey(accs, Val(:LogPrior))
+    if has_logprior
+        old_logprior = getacc(accs, Val(:LogPrior))
+        vi = setaccs!!(vi, (old_logprior,))
+    end
+    vi = settrans!!(last(evaluate!!(model, vi, ctx)), t)
+    # Restore the accumulators.
+    if has_logprior
+        new_logprior = getacc(vi, Val(:LogPrior))
+        accs = setacc!!(accs, new_logprior)
+    end
+    vi = setaccs!!(vi, accs)
+    return vi
 end
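+
+# Sketch of the net effect, assuming `vi` carries the default accumulators: after
+#
+#     vi_linked = link!!(DynamicTransformation(), vi, model)
+#
+# the LogLikelihood and NumProduce accumulators of `vi_linked` are the ones `vi` had before
+# the call (restored above), while LogPrior has been recomputed during the transforming
+# evaluation, including the log-absdet-Jacobian term.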
 
 function link(t::DynamicTransformation, vi::AbstractVarInfo, model::Model)
diff --git a/src/utils.jl b/src/utils.jl
index 71919480c..9a9f39ede 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -18,23 +18,29 @@ const LogProbType = float(Real)
 """
     @addlogprob!(ex)
 
-Add the result of the evaluation of `ex` to the joint log probability.
+Add a term to the log joint.
 
-# Examples
+If `ex` evaluates to a `NamedTuple` with keys `:loglikelihood` and/or `:logprior`, the
+values are added to the log likelihood and log prior respectively.
+
+If `ex` evaluates to a number, it is added to the log likelihood.
 
-This macro allows you to [include arbitrary terms in the likelihood](https://github.com/TuringLang/Turing.jl/issues/1332)
+# Examples
 
 ```jldoctest; setup = :(using Distributions)
-julia> myloglikelihood(x, μ) = loglikelihood(Normal(μ, 1), x);
+julia> mylogjoint(x, μ) = (; loglikelihood=loglikelihood(Normal(μ, 1), x), logprior=1.0);
 
 julia> @model function demo(x)
            μ ~ Normal()
-           @addlogprob! myloglikelihood(x, μ)
+           @addlogprob! mylogjoint(x, μ)
        end;
 
 julia> x = [1.3, -2.1];
 
-julia> loglikelihood(demo(x), (μ=0.2,)) ≈ myloglikelihood(x, 0.2)
+julia> loglikelihood(demo(x), (μ=0.2,)) ≈ mylogjoint(x, 0.2).loglikelihood
+true
+
+julia> logprior(demo(x), (μ=0.2,)) ≈ logpdf(Normal(), 0.2) + mylogjoint(x, 0.2).logprior
 true
 ```
 
@@ -44,7 +50,7 @@ and to [reject samples](https://github.com/TuringLang/Turing.jl/issues/1328):
 julia> @model function demo(x)
            m ~ MvNormal(zero(x), I)
            if dot(m, x) < 0
-               @addlogprob! -Inf
+               @addlogprob! (; loglikelihood=-Inf)
                # Exit the model evaluation early
                return
            end
@@ -55,37 +61,22 @@ julia> @model function demo(x)
 julia> logjoint(demo([-2.1]), (m=[0.2],)) == -Inf
 true
 ```
-
-!!! note
-    The `@addlogprob!` macro increases the accumulated log probability regardless of the evaluation context,
-    i.e., regardless of whether you evaluate the log prior, the log likelihood or the log joint density.
-    If you would like to avoid this behaviour you should check the evaluation context.
-    It can be accessed with the internal variable `__context__`.
-    For instance, in the following example the log density is not accumulated when only the log prior is computed:
-    ```jldoctest; setup = :(using Distributions)
-    julia> myloglikelihood(x, μ) = loglikelihood(Normal(μ, 1), x);
-
-    julia> @model function demo(x)
-               μ ~ Normal()
-               if DynamicPPL.leafcontext(__context__) !== PriorContext()
-                   @addlogprob! myloglikelihood(x, μ)
-               end
-           end;
-
-    julia> x = [1.3, -2.1];
-
-    julia> logprior(demo(x), (μ=0.2,)) ≈ logpdf(Normal(), 0.2)
-    true
-
-    julia> loglikelihood(demo(x), (μ=0.2,)) ≈ myloglikelihood(x, 0.2)
-    true
-    ```
 """
 macro addlogprob!(ex)
     return quote
-        $(esc(:(__varinfo__))) = acclogp!!(
-            $(esc(:(__context__))), $(esc(:(__varinfo__))), $(esc(ex))
-        )
+        val = $(esc(ex))
+        vi = $(esc(:(__varinfo__)))
+        if val isa Number
+            if hasacc(vi, Val(:LogLikelihood))
+                $(esc(:(__varinfo__))) = accloglikelihood!!($(esc(:(__varinfo__))), val)
+            end
+        elseif val isa NamedTuple
+            $(esc(:(__varinfo__))) = acclogp!!(
+                $(esc(:(__varinfo__))), val; ignore_missing_accumulator=true
+            )
+        else
+            error("logp must be a Number or a NamedTuple.")
+        end
     end
 end
 
diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl
index d3bfd697a..3ec474940 100644
--- a/src/values_as_in_model.jl
+++ b/src/values_as_in_model.jl
@@ -65,29 +65,24 @@ end
 function tilde_assume(context::ValuesAsInModelContext, right, vn, vi)
     if is_tracked_value(right)
         value = right.value
-        logp = zero(getlogp(vi))
     else
-        value, logp, vi = tilde_assume(childcontext(context), right, vn, vi)
+        value, vi = tilde_assume(childcontext(context), right, vn, vi)
     end
-    # Save the value.
     push!(context, vn, value)
-    # Save the value.
-    # Pass on.
-    return value, logp, vi
+    return value, vi
 end
 function tilde_assume(
     rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, vn, vi
 )
     if is_tracked_value(right)
         value = right.value
-        logp = zero(getlogp(vi))
     else
-        value, logp, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi)
+        value, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi)
     end
     # Save the value.
     push!(context, vn, value)
     # Pass on.
-    return value, logp, vi
+    return value, vi
 end
 
 """
diff --git a/src/varinfo.jl b/src/varinfo.jl
index 360857ef7..6a968da4d 100644
--- a/src/varinfo.jl
+++ b/src/varinfo.jl
@@ -69,10 +69,9 @@ end
 ###########
 
 """
-    struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo
+    struct VarInfo{Tmeta,Accs<:AccumulatorTuple} <: AbstractVarInfo
         metadata::Tmeta
-        logp::Base.RefValue{Tlogp}
-        num_produce::Base.RefValue{Int}
+        accs::Accs
     end
 
 A light wrapper over some kind of metadata.
@@ -98,12 +97,14 @@ Note that for NTVarInfo, it is the user's responsibility to ensure that each
 symbol is visited at least once during model evaluation, regardless of any
 stochastic branching.
 """
-struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo
+struct VarInfo{Tmeta,Accs<:AccumulatorTuple} <: AbstractVarInfo
     metadata::Tmeta
-    logp::Base.RefValue{Tlogp}
-    num_produce::Base.RefValue{Int}
+    accs::Accs
 end
-VarInfo(meta=Metadata()) = VarInfo(meta, Ref{LogProbType}(0.0), Ref(0))
+function VarInfo(meta=Metadata())
+    return VarInfo(meta, default_accumulators())
+end
+
 """
     VarInfo([rng, ]model[, sampler, context])
 
@@ -285,10 +286,8 @@ function typed_varinfo(vi::UntypedVarInfo)
             ),
         )
     end
-    logp = getlogp(vi)
-    num_produce = get_num_produce(vi)
     nt = NamedTuple{syms_tuple}(Tuple(new_metas))
-    return VarInfo(nt, Ref(logp), Ref(num_produce))
+    return VarInfo(nt, deepcopy(vi.accs))
 end
 function typed_varinfo(vi::NTVarInfo)
     # This function preserves the behaviour of typed_varinfo(vi) where vi is
@@ -349,8 +348,7 @@ single `VarNamedVector` as its metadata field.
 """
 function untyped_vector_varinfo(vi::UntypedVarInfo)
     md = metadata_to_varnamedvector(vi.metadata)
-    lp = getlogp(vi)
-    return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi)))
+    return VarInfo(md, deepcopy(vi.accs))
 end
 function untyped_vector_varinfo(
     rng::Random.AbstractRNG,
@@ -393,15 +391,12 @@ NamedTuple of `VarNamedVector`s as its metadata field.
 """
 function typed_vector_varinfo(vi::NTVarInfo)
     md = map(metadata_to_varnamedvector, vi.metadata)
-    lp = getlogp(vi)
-    return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi)))
+    return VarInfo(md, deepcopy(vi.accs))
 end
 function typed_vector_varinfo(vi::UntypedVectorVarInfo)
     new_metas = group_by_symbol(vi.metadata)
-    logp = getlogp(vi)
-    num_produce = get_num_produce(vi)
     nt = NamedTuple(new_metas)
-    return VarInfo(nt, Ref(logp), Ref(num_produce))
+    return VarInfo(nt, deepcopy(vi.accs))
 end
 function typed_vector_varinfo(
     rng::Random.AbstractRNG,
@@ -441,13 +436,22 @@ vector_length(md::Metadata) = sum(length, md.ranges)
 
 function unflatten(vi::VarInfo, x::AbstractVector)
     md = unflatten_metadata(vi.metadata, x)
-    # Note that use of RefValue{eltype(x)} rather than Ref is necessary to deal with cases
-    # where e.g. x is a type gradient of some AD backend.
-    return VarInfo(
-        md,
-        Base.RefValue{float_type_with_fallback(eltype(x))}(getlogp(vi)),
-        Ref(get_num_produce(vi)),
+    # Use of float_type_with_fallback(eltype(x)) is necessary to deal with cases where x is
+    # a gradient type of some AD backend.
+    # TODO(mhauru) How could we do this more cleanly? The problem case is map_accumulator!!
+    # for ThreadSafeVarInfo. In that one, if the map produces e.g. a ForwardDiff.Dual but
+    # the accumulators in the VarInfo are plain floats, we error since we can't change the
+    # element type of ThreadSafeVarInfo.accs_by_thread. However, doing this conversion here
+    # messes with cases like using Float32 for logprobs and Float64 for x. Also, this is
+    # just plain ugly and hacky.
+    # The below line is finicky for type stability. For instance, storing the eltype to
+    # convert to in an intermediate variable makes this unstable (constant propagation
+    # fails). Take care when editing.
+    accs = map(
+        acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc),
+        deepcopy(getaccs(vi)),
     )
+    return VarInfo(md, accs)
 end
 
 # We would call this `unflatten` if not for `unflatten` having a method for NamedTuples in
@@ -529,7 +533,7 @@ end
 
 function subset(varinfo::VarInfo, vns::AbstractVector{<:VarName})
     metadata = subset(varinfo.metadata, vns)
-    return VarInfo(metadata, deepcopy(varinfo.logp), deepcopy(varinfo.num_produce))
+    return VarInfo(metadata, deepcopy(varinfo.accs))
 end
 
 function subset(metadata::NamedTuple, vns::AbstractVector{<:VarName})
@@ -618,9 +622,7 @@ end
 
 function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo)
     metadata = merge_metadata(varinfo_left.metadata, varinfo_right.metadata)
-    return VarInfo(
-        metadata, Ref(getlogp(varinfo_right)), Ref(get_num_produce(varinfo_right))
-    )
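+    # As before, the merged VarInfo keeps the right-hand side's accumulators (previously
+    # its logp and num_produce).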
+    return VarInfo(metadata, deepcopy(varinfo_right.accs))
 end
 
 function merge_metadata(vnv_left::VarNamedVector, vnv_right::VarNamedVector)
@@ -973,8 +975,8 @@ end
 
 function BangBang.empty!!(vi::VarInfo)
     _empty!(vi.metadata)
-    resetlogp!!(vi)
-    reset_num_produce!(vi)
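+    # resetlogp!! and reset_num_produce!! may return a new VarInfo rather than mutating in
+    # place, so rebind vi to their return values.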
+    vi = resetlogp!!(vi)
+    vi = reset_num_produce!!(vi)
     return vi
 end
 
@@ -1008,46 +1010,8 @@ end
 istrans(vi::VarInfo, vn::VarName) = istrans(getmetadata(vi, vn), vn)
 istrans(md::Metadata, vn::VarName) = is_flagged(md, vn, "trans")
 
-getlogp(vi::VarInfo) = vi.logp[]
-
-function setlogp!!(vi::VarInfo, logp)
-    vi.logp[] = logp
-    return vi
-end
-
-function acclogp!!(vi::VarInfo, logp)
-    vi.logp[] += logp
-    return vi
-end
-
-"""
-    get_num_produce(vi::VarInfo)
-
-Return the `num_produce` of `vi`.
-"""
-get_num_produce(vi::VarInfo) = vi.num_produce[]
-
-"""
-    set_num_produce!(vi::VarInfo, n::Int)
-
-Set the `num_produce` field of `vi` to `n`.
-"""
-set_num_produce!(vi::VarInfo, n::Int) = vi.num_produce[] = n
-
-"""
-    increment_num_produce!(vi::VarInfo)
-
-Add 1 to `num_produce` in `vi`.
-"""
-increment_num_produce!(vi::VarInfo) = vi.num_produce[] += 1
-
-"""
-    reset_num_produce!(vi::VarInfo)
-
-Reset the value of `num_produce` the log of the joint probability of the observed data
-and parameters sampled in `vi` to 0.
-"""
-reset_num_produce!(vi::VarInfo) = set_num_produce!(vi, 0)
+getaccs(vi::VarInfo) = vi.accs
+setaccs!!(vi::VarInfo, accs::AccumulatorTuple) = Accessors.@set vi.accs = accs
 
 # Need to introduce the _isempty to avoid type piracy of isempty(::NamedTuple).
 isempty(vi::VarInfo) = _isempty(vi.metadata)
@@ -1061,7 +1025,7 @@ function link!!(::DynamicTransformation, vi::NTVarInfo, model::Model)
     vns = all_varnames_grouped_by_symbol(vi)
     # If we're working with a `VarNamedVector`, we always use immutable.
     has_varnamedvector(vi) && return _link(model, vi, vns)
-    _link!(vi, vns)
+    vi = _link!!(vi, vns)
     return vi
 end
 
@@ -1069,7 +1033,7 @@ function link!!(::DynamicTransformation, vi::VarInfo, model::Model)
     vns = keys(vi)
     # If we're working with a `VarNamedVector`, we always use immutable.
     has_varnamedvector(vi) && return _link(model, vi, vns)
-    _link!(vi, vns)
+    vi = _link!!(vi, vns)
     return vi
 end
 
@@ -1082,8 +1046,7 @@ end
 function link!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model)
     # If we're working with a `VarNamedVector`, we always use immutable.
     has_varnamedvector(vi) && return _link(model, vi, vns)
-    # Call `_link!` instead of `link!` to avoid deprecation warning.
-    _link!(vi, vns)
+    vi = _link!!(vi, vns)
     return vi
 end
 
@@ -1098,27 +1061,28 @@ function link!!(
     return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, vns, model)
 end
 
-function _link!(vi::UntypedVarInfo, vns)
+function _link!!(vi::UntypedVarInfo, vns)
     # TODO: Change to a lazy iterator over `vns`
     if ~istrans(vi, vns[1])
         for vn in vns
             f = internal_to_linked_internal_transform(vi, vn)
-            _inner_transform!(vi, vn, f)
-            settrans!!(vi, true, vn)
+            vi = _inner_transform!(vi, vn, f)
+            vi = settrans!!(vi, true, vn)
         end
+        return vi
     else
         @warn("[DynamicPPL] attempt to link a linked vi")
     end
 end
 
-# If we try to _link! a NTVarInfo with a Tuple of VarNames, first convert it to a
+# If we try to _link!! a NTVarInfo with a Tuple of VarNames, first convert it to a
 # NamedTuple that matches the structure of the NTVarInfo.
-function _link!(vi::NTVarInfo, vns::VarNameTuple)
-    return _link!(vi, group_varnames_by_symbol(vns))
+function _link!!(vi::NTVarInfo, vns::VarNameTuple)
+    return _link!!(vi, group_varnames_by_symbol(vns))
 end
 
-function _link!(vi::NTVarInfo, vns::NamedTuple)
-    return _link!(vi.metadata, vi, vns)
+function _link!!(vi::NTVarInfo, vns::NamedTuple)
+    return _link!!(vi.metadata, vi, vns)
 end
 
 """
@@ -1130,7 +1094,7 @@ function filter_subsumed(filter_vns, filtered_vns)
     return filter(x -> any(subsumes(y, x) for y in filter_vns), filtered_vns)
 end
 
-@generated function _link!(
+@generated function _link!!(
     ::NamedTuple{metadata_names}, vi, vns::NamedTuple{vns_names}
 ) where {metadata_names,vns_names}
     expr = Expr(:block)
@@ -1148,8 +1112,8 @@ end
                         # Iterate over all `f_vns` and transform
                         for vn in f_vns
                             f = internal_to_linked_internal_transform(vi, vn)
-                            _inner_transform!(vi, vn, f)
-                            settrans!!(vi, true, vn)
+                            vi = _inner_transform!(vi, vn, f)
+                            vi = settrans!!(vi, true, vn)
                         end
                     else
                         @warn("[DynamicPPL] attempt to link a linked vi")
@@ -1158,6 +1122,7 @@ end
             end,
         )
     end
+    push!(expr.args, :(return vi))
     return expr
 end
 
@@ -1165,8 +1130,7 @@ function invlink!!(::DynamicTransformation, vi::NTVarInfo, model::Model)
     vns = all_varnames_grouped_by_symbol(vi)
     # If we're working with a `VarNamedVector`, we always use immutable.
     has_varnamedvector(vi) && return _invlink(model, vi, vns)
-    # Call `_invlink!` instead of `invlink!` to avoid deprecation warning.
-    _invlink!(vi, vns)
+    vi = _invlink!!(vi, vns)
     return vi
 end
 
@@ -1174,7 +1138,7 @@ function invlink!!(::DynamicTransformation, vi::VarInfo, model::Model)
     vns = keys(vi)
     # If we're working with a `VarNamedVector`, we always use immutable.
     has_varnamedvector(vi) && return _invlink(model, vi, vns)
-    _invlink!(vi, vns)
+    vi = _invlink!!(vi, vns)
     return vi
 end
 
@@ -1187,8 +1151,7 @@ end
 function invlink!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model)
     # If we're working with a `VarNamedVector`, we always use immutable.
     has_varnamedvector(vi) && return _invlink(model, vi, vns)
-    # Call `_invlink!` instead of `invlink!` to avoid deprecation warning.
-    _invlink!(vi, vns)
+    vi = _invlink!!(vi, vns)
     return vi
 end
 
@@ -1211,29 +1174,30 @@ function maybe_invlink_before_eval!!(vi::VarInfo, model::Model)
     return maybe_invlink_before_eval!!(t, vi, model)
 end
 
-function _invlink!(vi::UntypedVarInfo, vns)
+function _invlink!!(vi::UntypedVarInfo, vns)
     if istrans(vi, vns[1])
         for vn in vns
             f = linked_internal_to_internal_transform(vi, vn)
-            _inner_transform!(vi, vn, f)
-            settrans!!(vi, false, vn)
+            vi = _inner_transform!(vi, vn, f)
+            vi = settrans!!(vi, false, vn)
         end
+        return vi
     else
         @warn("[DynamicPPL] attempt to invlink an invlinked vi")
     end
 end
 
-# If we try to _invlink! a NTVarInfo with a Tuple of VarNames, first convert it to a
+# If we try to _invlink!! a NTVarInfo with a Tuple of VarNames, first convert it to a
 # NamedTuple that matches the structure of the NTVarInfo.
-function _invlink!(vi::NTVarInfo, vns::VarNameTuple)
-    return _invlink!(vi.metadata, vi, group_varnames_by_symbol(vns))
+function _invlink!!(vi::NTVarInfo, vns::VarNameTuple)
+    return _invlink!!(vi.metadata, vi, group_varnames_by_symbol(vns))
 end
 
-function _invlink!(vi::NTVarInfo, vns::NamedTuple)
-    return _invlink!(vi.metadata, vi, vns)
+function _invlink!!(vi::NTVarInfo, vns::NamedTuple)
+    return _invlink!!(vi.metadata, vi, vns)
 end
 
-@generated function _invlink!(
+@generated function _invlink!!(
     ::NamedTuple{metadata_names}, vi, vns::NamedTuple{vns_names}
 ) where {metadata_names,vns_names}
     expr = Expr(:block)
@@ -1251,8 +1215,8 @@ end
                     # Iterate over all `f_vns` and transform
                     for vn in f_vns
                         f = linked_internal_to_internal_transform(vi, vn)
-                        _inner_transform!(vi, vn, f)
-                        settrans!!(vi, false, vn)
+                        vi = _inner_transform!(vi, vn, f)
+                        vi = settrans!!(vi, false, vn)
                     end
                 else
                     @warn("[DynamicPPL] attempt to invlink an invlinked vi")
@@ -1260,6 +1224,7 @@ end
             end,
         )
     end
+    push!(expr.args, :(return vi))
     return expr
 end
 
@@ -1276,7 +1241,7 @@ function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f)
     setrange!(md, vn, start:(start + length(yvec) - 1))
     # Set the new value.
     setval!(md, yvec, vn)
-    acclogp!!(vi, -logjac)
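+    # The log-abs-det-Jacobian correction of the transform is part of the prior density of
+    # the transformed variable, hence acclogprior!!.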
+    vi = acclogprior!!(vi, -logjac)
     return vi
 end
 
@@ -1311,8 +1276,10 @@ end
 
 function _link(model::Model, varinfo::VarInfo, vns)
     varinfo = deepcopy(varinfo)
-    md = _link_metadata!!(model, varinfo, varinfo.metadata, vns)
-    return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)))
+    md, logjac = _link_metadata!!(model, varinfo, varinfo.metadata, vns)
+    new_varinfo = VarInfo(md, varinfo.accs)
+    new_varinfo = acclogprior!!(new_varinfo, -logjac)
+    return new_varinfo
 end
 
 # If we try to _link a NTVarInfo with a Tuple of VarNames, first convert it to a
@@ -1323,8 +1290,10 @@ end
 
 function _link(model::Model, varinfo::NTVarInfo, vns::NamedTuple)
     varinfo = deepcopy(varinfo)
-    md = _link_metadata!(model, varinfo, varinfo.metadata, vns)
-    return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)))
+    md, logjac = _link_metadata!(model, varinfo, varinfo.metadata, vns)
+    new_varinfo = VarInfo(md, varinfo.accs)
+    new_varinfo = acclogprior!!(new_varinfo, -logjac)
+    return new_varinfo
 end
 
 @generated function _link_metadata!(
@@ -1333,20 +1302,39 @@ end
     metadata::NamedTuple{metadata_names},
     vns::NamedTuple{vns_names},
 ) where {metadata_names,vns_names}
-    vals = Expr(:tuple)
+    expr = quote
+        cumulative_logjac = zero(LogProbType)
+    end
+    mds = Expr(:tuple)
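+    # Build a tuple expression in which each field to be linked contributes its new
+    # metadata and adds its logjac to cumulative_logjac; other fields pass through as-is.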
     for f in metadata_names
         if f in vns_names
-            push!(vals.args, :(_link_metadata!!(model, varinfo, metadata.$f, vns.$f)))
+            push!(
+                mds.args,
+                quote
+                    begin
+                        md, logjac = _link_metadata!!(model, varinfo, metadata.$f, vns.$f)
+                        cumulative_logjac += logjac
+                        md
+                    end
+                end,
+            )
         else
-            push!(vals.args, :(metadata.$f))
+            push!(mds.args, :(metadata.$f))
         end
     end
 
-    return :(NamedTuple{$metadata_names}($vals))
+    push!(
+        expr.args,
+        quote
+            NamedTuple{$metadata_names}($mds), cumulative_logjac
+        end,
+    )
+    return expr
 end
 
 function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns)
     vns = metadata.vns
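+    # Accumulate the total log-abs-det-Jacobian of the link transforms and return it
+    # alongside the new metadata, instead of adding it to the VarInfo's log prob here.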
+    cumulative_logjac = zero(LogProbType)
 
     # Construct the new transformed values, and keep track of their lengths.
     vals_new = map(vns) do vn
@@ -1364,7 +1352,7 @@ function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_
         # Vectorize value.
         yvec = tovec(y)
         # Accumulate the log-abs-det jacobian correction.
-        acclogp!!(varinfo, -logjac)
+        cumulative_logjac += logjac
         # Mark as transformed.
         settrans!!(varinfo, true, vn)
         # Return the vectorized transformed value.
@@ -1389,7 +1377,8 @@ function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_
         metadata.dists,
         metadata.orders,
         metadata.flags,
-    )
+    ),
+    cumulative_logjac
 end
 
 function _link_metadata!!(
@@ -1397,6 +1386,7 @@ function _link_metadata!!(
 )
     vns = target_vns === nothing ? keys(metadata) : target_vns
     dists = extract_priors(model, varinfo)
+    cumulative_logjac = zero(LogProbType)
     for vn in vns
         # First transform from however the variable is stored in vnv to the model
         # representation.
@@ -1409,11 +1399,11 @@ function _link_metadata!!(
         val_new, logjac2 = with_logabsdet_jacobian(transform_to_linked, val_orig)
         # TODO(mhauru) We are calling a !! function but ignoring the return value.
         # Fix this when attending to issue #653.
-        acclogp!!(varinfo, -logjac1 - logjac2)
+        cumulative_logjac += logjac1 + logjac2
         metadata = setindex_internal!!(metadata, val_new, vn, transform_from_linked)
         settrans!(metadata, true, vn)
     end
-    return metadata
+    return metadata, cumulative_logjac
 end
 
 function invlink(::DynamicTransformation, vi::NTVarInfo, model::Model)
@@ -1449,11 +1439,10 @@ end
 
 function _invlink(model::Model, varinfo::VarInfo, vns)
     varinfo = deepcopy(varinfo)
-    return VarInfo(
-        _invlink_metadata!!(model, varinfo, varinfo.metadata, vns),
-        Base.Ref(getlogp(varinfo)),
-        Ref(get_num_produce(varinfo)),
-    )
+    md, logjac = _invlink_metadata!!(model, varinfo, varinfo.metadata, vns)
+    new_varinfo = VarInfo(md, varinfo.accs)
+    new_varinfo = acclogprior!!(new_varinfo, -logjac)
+    return new_varinfo
 end
 
 # If we try to _invlink a NTVarInfo with a Tuple of VarNames, first convert it to a
@@ -1464,8 +1453,10 @@ end
 
 function _invlink(model::Model, varinfo::NTVarInfo, vns::NamedTuple)
     varinfo = deepcopy(varinfo)
-    md = _invlink_metadata!(model, varinfo, varinfo.metadata, vns)
-    return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)))
+    md, logjac = _invlink_metadata!(model, varinfo, varinfo.metadata, vns)
+    new_varinfo = VarInfo(md, varinfo.accs)
+    new_varinfo = acclogprior!!(new_varinfo, -logjac)
+    return new_varinfo
 end
 
 @generated function _invlink_metadata!(
@@ -1474,20 +1465,41 @@ end
     metadata::NamedTuple{metadata_names},
     vns::NamedTuple{vns_names},
 ) where {metadata_names,vns_names}
-    vals = Expr(:tuple)
+    expr = quote
+        cumulative_logjac = zero(LogProbType)
+    end
+    mds = Expr(:tuple)
     for f in metadata_names
         if (f in vns_names)
-            push!(vals.args, :(_invlink_metadata!!(model, varinfo, metadata.$f, vns.$f)))
+            push!(
+                mds.args,
+                quote
+                    begin
+                        md, logjac = _invlink_metadata!!(
+                            model, varinfo, metadata.$f, vns.$f
+                        )
+                        cumulative_logjac += logjac
+                        md
+                    end
+                end,
+            )
         else
-            push!(vals.args, :(metadata.$f))
+            push!(mds.args, :(metadata.$f))
         end
     end
 
-    return :(NamedTuple{$metadata_names}($vals))
+    push!(
+        expr.args,
+        quote
+            (NamedTuple{$metadata_names}($mds), cumulative_logjac)
+        end,
+    )
+    return expr
 end
 
 function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns)
     vns = metadata.vns
+    cumulative_logjac = zero(LogProbType)
 
     # Construct the new transformed values, and keep track of their lengths.
     vals_new = map(vns) do vn
@@ -1506,7 +1518,7 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ
         # Vectorize value.
         xvec = tovec(x)
         # Accumulate the log-abs-det jacobian correction.
-        acclogp!!(varinfo, -logjac)
+        cumulative_logjac += logjac
         # Mark as no longer transformed.
         settrans!!(varinfo, false, vn)
         # Return the vectorized transformed value.
@@ -1531,24 +1543,26 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ
         metadata.dists,
         metadata.orders,
         metadata.flags,
-    )
+    ),
+    cumulative_logjac
 end
 
 function _invlink_metadata!!(
     ::Model, varinfo::VarInfo, metadata::VarNamedVector, target_vns
 )
     vns = target_vns === nothing ? keys(metadata) : target_vns
+    cumulative_logjac = zero(LogProbType)
     for vn in vns
         transform = gettransform(metadata, vn)
         old_val = getindex_internal(metadata, vn)
         new_val, logjac = with_logabsdet_jacobian(transform, old_val)
         # TODO(mhauru) We are calling a !! function but ignoring the return value.
-        acclogp!!(varinfo, -logjac)
+        cumulative_logjac += logjac
         new_transform = from_vec_transform(new_val)
         metadata = setindex_internal!!(metadata, tovec(new_val), vn, new_transform)
         settrans!(metadata, false, vn)
     end
-    return metadata
+    return metadata, cumulative_logjac
 end
 
 # TODO(mhauru) The treatment of the case when some variables are linked and others are not
@@ -1705,19 +1719,35 @@ function Base.haskey(vi::NTVarInfo, vn::VarName)
 end
 
 function Base.show(io::IO, ::MIME"text/plain", vi::UntypedVarInfo)
-    vi_str = """
-    /=======================================================================
-    | VarInfo
-    |-----------------------------------------------------------------------
-    | Varnames  :   $(string(vi.metadata.vns))
-    | Range     :   $(vi.metadata.ranges)
-    | Vals      :   $(vi.metadata.vals)
-    | Orders    :   $(vi.metadata.orders)
-    | Logp      :   $(getlogp(vi))
-    | #produce  :   $(get_num_produce(vi))
-    | flags     :   $(vi.metadata.flags)
-    \\=======================================================================
-    """
+    lines = Tuple{String,Any}[
+        ("VarNames", vi.metadata.vns),
+        ("Range", vi.metadata.ranges),
+        ("Vals", vi.metadata.vals),
+        ("Orders", vi.metadata.orders),
+    ]
+    for accname in acckeys(vi)
+        push!(lines, (string(accname), getacc(vi, Val(accname))))
+    end
+    push!(lines, ("flags", vi.metadata.flags))
+    max_name_length = maximum(map(length ∘ first, lines))
+    fmt = Printf.Format("%-$(max_name_length)s")
+    vi_str = (
+        """
+        /=======================================================================
+        | VarInfo
+        |-----------------------------------------------------------------------
+        """ *
+        prod(
+            map(lines) do (name, value)
+                """
+                | $(Printf.format(fmt, name)) : $(value)
+                """
+            end,
+        ) *
+        """
+        \\=======================================================================
+        """
+    )
     return print(io, vi_str)
 end
 
@@ -1747,7 +1777,11 @@ end
 function Base.show(io::IO, vi::UntypedVarInfo)
     print(io, "VarInfo (")
     _show_varnames(io, vi)
-    print(io, "; logp: ", round(getlogp(vi); digits=3))
+    print(io, "; accumulators: ")
+    # TODO(mhauru) This uses "text/plain" because we are printing quite a condensed
+    # representation of vi anyway. However, technically `show(io, x)` should give full
+    # details of x and preferably output valid Julia code.
+    show(io, MIME"text/plain"(), getaccs(vi))
     return print(io, ")")
 end
 
diff --git a/test/accumulators.jl b/test/accumulators.jl
new file mode 100644
index 000000000..36bb95e46
--- /dev/null
+++ b/test/accumulators.jl
@@ -0,0 +1,176 @@
+module AccumulatorTests
+
+using Test
+using Distributions
+using DynamicPPL
+using DynamicPPL:
+    AccumulatorTuple,
+    LogLikelihoodAccumulator,
+    LogPriorAccumulator,
+    NumProduceAccumulator,
+    accumulate_assume!!,
+    accumulate_observe!!,
+    combine,
+    convert_eltype,
+    getacc,
+    increment,
+    map_accumulator,
+    setacc!!,
+    split
+
+@testset "accumulators" begin
+    @testset "individual accumulator types" begin
+        @testset "constructors" begin
+            @test LogPriorAccumulator(0.0) ==
+                LogPriorAccumulator() ==
+                LogPriorAccumulator{Float64}() ==
+                LogPriorAccumulator{Float64}(0.0) ==
+                zero(LogPriorAccumulator(1.0))
+            @test LogLikelihoodAccumulator(0.0) ==
+                LogLikelihoodAccumulator() ==
+                LogLikelihoodAccumulator{Float64}() ==
+                LogLikelihoodAccumulator{Float64}(0.0) ==
+                zero(LogLikelihoodAccumulator(1.0))
+            @test NumProduceAccumulator(0) ==
+                NumProduceAccumulator() ==
+                NumProduceAccumulator{Int}() ==
+                NumProduceAccumulator{Int}(0) ==
+                zero(NumProduceAccumulator(1))
+        end
+
+        @testset "addition and incrementation" begin
+            @test LogPriorAccumulator(1.0f0) + LogPriorAccumulator(1.0f0) ==
+                LogPriorAccumulator(2.0f0)
+            @test LogPriorAccumulator(1.0) + LogPriorAccumulator(1.0f0) ==
+                LogPriorAccumulator(2.0)
+            @test LogLikelihoodAccumulator(1.0f0) + LogLikelihoodAccumulator(1.0f0) ==
+                LogLikelihoodAccumulator(2.0f0)
+            @test LogLikelihoodAccumulator(1.0) + LogLikelihoodAccumulator(1.0f0) ==
+                LogLikelihoodAccumulator(2.0)
+            @test increment(NumProduceAccumulator()) == NumProduceAccumulator(1)
+            @test increment(NumProduceAccumulator{UInt8}()) ==
+                NumProduceAccumulator{UInt8}(1)
+        end
+
+        @testset "split and combine" begin
+            for acc in [
+                LogPriorAccumulator(1.0),
+                LogLikelihoodAccumulator(1.0),
+                NumProduceAccumulator(1),
+                LogPriorAccumulator(1.0f0),
+                LogLikelihoodAccumulator(1.0f0),
+                NumProduceAccumulator(UInt8(1)),
+            ]
+                @test combine(acc, split(acc)) == acc
+            end
+        end
+
+        @testset "conversions" begin
+            @test convert(LogPriorAccumulator{Float32}, LogPriorAccumulator(1.0)) ==
+                LogPriorAccumulator{Float32}(1.0f0)
+            @test convert(
+                LogLikelihoodAccumulator{Float32}, LogLikelihoodAccumulator(1.0)
+            ) == LogLikelihoodAccumulator{Float32}(1.0f0)
+            @test convert(NumProduceAccumulator{UInt8}, NumProduceAccumulator(1)) ==
+                NumProduceAccumulator{UInt8}(1)
+
+            @test convert_eltype(Float32, LogPriorAccumulator(1.0)) ==
+                LogPriorAccumulator{Float32}(1.0f0)
+            @test convert_eltype(Float32, LogLikelihoodAccumulator(1.0)) ==
+                LogLikelihoodAccumulator{Float32}(1.0f0)
+        end
+
+        @testset "accumulate_assume" begin
+            val = 2.0
+            logjac = pi
+            vn = @varname(x)
+            dist = Normal()
+            @test accumulate_assume!!(LogPriorAccumulator(1.0), val, logjac, vn, dist) ==
+                LogPriorAccumulator(1.0 + logjac + logpdf(dist, val))
+            @test accumulate_assume!!(
+                LogLikelihoodAccumulator(1.0), val, logjac, vn, dist
+            ) == LogLikelihoodAccumulator(1.0)
+            @test accumulate_assume!!(NumProduceAccumulator(1), val, logjac, vn, dist) ==
+                NumProduceAccumulator(1)
+        end
+
+        @testset "accumulate_observe" begin
+            right = Normal()
+            left = 2.0
+            vn = @varname(x)
+            @test accumulate_observe!!(LogPriorAccumulator(1.0), right, left, vn) ==
+                LogPriorAccumulator(1.0)
+            @test accumulate_observe!!(LogLikelihoodAccumulator(1.0), right, left, vn) ==
+                LogLikelihoodAccumulator(1.0 + logpdf(right, left))
+            @test accumulate_observe!!(NumProduceAccumulator(1), right, left, vn) ==
+                NumProduceAccumulator(2)
+        end
+    end
+
+    @testset "accumulator tuples" begin
+        # Some accumulators we'll use for testing
+        lp_f64 = LogPriorAccumulator(1.0)
+        lp_f32 = LogPriorAccumulator(1.0f0)
+        ll_f64 = LogLikelihoodAccumulator(1.0)
+        ll_f32 = LogLikelihoodAccumulator(1.0f0)
+        np_i64 = NumProduceAccumulator(1)
+
+        @testset "constructors" begin
+            @test AccumulatorTuple(lp_f64, ll_f64) == AccumulatorTuple((lp_f64, ll_f64))
+            # Names in NamedTuple arguments are ignored
+            @test AccumulatorTuple((; a=lp_f64)) == AccumulatorTuple(lp_f64)
+
+            # Can't have two accumulators of the same type.
+            @test_throws "duplicate field name" AccumulatorTuple(lp_f64, lp_f64)
+            # Not even if their element types differ.
+            @test_throws "duplicate field name" AccumulatorTuple(lp_f64, lp_f32)
+        end
+
+        @testset "basic operations" begin
+            at_all64 = AccumulatorTuple(lp_f64, ll_f64, np_i64)
+
+            @test at_all64[:LogPrior] == lp_f64
+            @test at_all64[:LogLikelihood] == ll_f64
+            @test at_all64[:NumProduce] == np_i64
+
+            @test haskey(AccumulatorTuple(np_i64), Val(:NumProduce))
+            @test ~haskey(AccumulatorTuple(np_i64), Val(:LogPrior))
+            @test length(AccumulatorTuple(lp_f64, ll_f64, np_i64)) == 3
+            @test keys(at_all64) == (:LogPrior, :LogLikelihood, :NumProduce)
+            @test collect(at_all64) == [lp_f64, ll_f64, np_i64]
+
+            # Replace the existing LogPriorAccumulator
+            @test setacc!!(at_all64, lp_f32)[:LogPrior] == lp_f32
+            # Check that setacc!! didn't modify the original
+            @test at_all64 == AccumulatorTuple(lp_f64, ll_f64, np_i64)
+            # Add a new accumulator type.
+            @test setacc!!(AccumulatorTuple(lp_f64), ll_f64) ==
+                AccumulatorTuple(lp_f64, ll_f64)
+
+            @test getacc(at_all64, Val(:LogPrior)) == lp_f64
+        end
+
+        @testset "map_accumulator(s)!!" begin
+            # map over all accumulators
+            accs = AccumulatorTuple(lp_f32, ll_f32)
+            @test map(zero, accs) == AccumulatorTuple(
+                LogPriorAccumulator(0.0f0), LogLikelihoodAccumulator(0.0f0)
+            )
+            # Test that the original wasn't modified.
+            @test accs == AccumulatorTuple(lp_f32, ll_f32)
+
+            # A map with a closure that changes the types of the accumulators.
+            @test map(acc -> convert_eltype(Float64, acc), accs) ==
+                AccumulatorTuple(LogPriorAccumulator(1.0), LogLikelihoodAccumulator(1.0))
+
+            # only apply to a particular accumulator
+            @test map_accumulator(zero, accs, Val(:LogLikelihood)) ==
+                AccumulatorTuple(lp_f32, LogLikelihoodAccumulator(0.0f0))
+            @test map_accumulator(
+                acc -> convert_eltype(Float64, acc), accs, Val(:LogLikelihood)
+            ) == AccumulatorTuple(lp_f32, LogLikelihoodAccumulator(1.0))
+        end
+    end
+end
+
+end
diff --git a/test/compiler.jl b/test/compiler.jl
index a0286d405..81c018111 100644
--- a/test/compiler.jl
+++ b/test/compiler.jl
@@ -189,12 +189,12 @@ module Issue537 end
             global model_ = __model__
             global context_ = __context__
             global rng_ = __context__.rng
-            global lp = getlogp(__varinfo__)
+            global lp = getlogjoint(__varinfo__)
             return x
         end
         model = testmodel_missing3([1.0])
         varinfo = VarInfo(model)
-        @test getlogp(varinfo) == lp
+        @test getlogjoint(varinfo) == lp
         @test varinfo_ isa AbstractVarInfo
         @test model_ === model
         @test context_ isa SamplingContext
@@ -208,13 +208,13 @@ module Issue537 end
             global model_ = __model__
             global context_ = __context__
             global rng_ = __context__.rng
-            global lp = getlogp(__varinfo__)
+            global lp = getlogjoint(__varinfo__)
             return x
         end false
         lpold = lp
         model = testmodel_missing4([1.0])
         varinfo = VarInfo(model)
-        @test getlogp(varinfo) == lp == lpold
+        @test getlogjoint(varinfo) == lp == lpold
 
         # test DPPL#61
         @model function testmodel_missing5(z)
@@ -333,14 +333,14 @@ module Issue537 end
         function makemodel(p)
             @model function testmodel(x)
                 x[1] ~ Bernoulli(p)
-                global lp = getlogp(__varinfo__)
+                global lp = getlogjoint(__varinfo__)
                 return x
             end
             return testmodel
         end
         model = makemodel(0.5)([1.0])
         varinfo = VarInfo(model)
-        @test getlogp(varinfo) == lp
+        @test getlogjoint(varinfo) == lp
     end
     @testset "user-defined variable name" begin
         @model f1() = x ~ NamedDist(Normal(), :y)
@@ -364,9 +364,9 @@ module Issue537 end
         # TODO(torfjelde): We need conditioning for `Dict`.
         @test_broken f2_c() == 1
         @test_broken f3_c() == 1
-        @test_broken getlogp(VarInfo(f1_c)) ==
-            getlogp(VarInfo(f2_c)) ==
-            getlogp(VarInfo(f3_c))
+        @test_broken getlogjoint(VarInfo(f1_c)) ==
+            getlogjoint(VarInfo(f2_c)) ==
+            getlogjoint(VarInfo(f3_c))
     end
     @testset "custom tilde" begin
         @model demo() = begin
diff --git a/test/context_implementations.jl b/test/context_implementations.jl
index 0ec88c07c..ac6321d69 100644
--- a/test/context_implementations.jl
+++ b/test/context_implementations.jl
@@ -10,7 +10,7 @@
             end
         end
 
-        test([1, 1, -1])(VarInfo(), SampleFromPrior(), LikelihoodContext())
+        test([1, 1, -1])(VarInfo(), SampleFromPrior(), DefaultContext())
     end
 
     @testset "dot tilde with varying sizes" begin
@@ -18,13 +18,14 @@
             @model function test(x, size)
                 y = Array{Float64,length(size)}(undef, size...)
                 y .~ Normal(x)
-                return y, getlogp(__varinfo__)
+                return y
             end
 
             for ysize in ((2,), (2, 3), (2, 3, 4))
                 x = randn()
                 model = test(x, ysize)
-                y, lp = model()
+                y = model()
+                lp = logjoint(model, (; y=y))
                 @test lp ≈ sum(logpdf.(Normal.(x), y))
 
                 ys = [first(model()) for _ in 1:10_000]
diff --git a/test/contexts.jl b/test/contexts.jl
index 1ba099a37..5f22b75eb 100644
--- a/test/contexts.jl
+++ b/test/contexts.jl
@@ -9,7 +9,6 @@ using DynamicPPL:
     NodeTrait,
     IsLeaf,
     IsParent,
-    PointwiseLogdensityContext,
     contextual_isassumption,
     FixedContext,
     ConditionContext,
@@ -47,18 +46,11 @@ Base.IteratorSize(::Type{<:AbstractContext}) = Base.SizeUnknown()
 Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
 
 @testset "contexts.jl" begin
-    child_contexts = Dict(
+    contexts = Dict(
         :default => DefaultContext(),
-        :prior => PriorContext(),
-        :likelihood => LikelihoodContext(),
-    )
-
-    parent_contexts = Dict(
         :testparent => DynamicPPL.TestUtils.TestParentContext(DefaultContext()),
         :sampling => SamplingContext(),
-        :minibatch => MiniBatchContext(DefaultContext(), 0.0),
         :prefix => PrefixContext(@varname(x)),
-        :pointwiselogdensity => PointwiseLogdensityContext(),
         :condition1 => ConditionContext((x=1.0,)),
         :condition2 => ConditionContext(
             (x=1.0,), DynamicPPL.TestUtils.TestParentContext(ConditionContext((y=2.0,)))
@@ -70,8 +62,6 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
         :condition4 => ConditionContext((x=[1.0, missing],)),
     )
 
-    contexts = merge(child_contexts, parent_contexts)
-
     @testset "$(name)" for (name, context) in contexts
         @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
             DynamicPPL.TestUtils.test_context(context, model)
@@ -235,7 +225,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
                 # Values from outer context should override inner one
                 ctx1 = ConditionContext(n1, ConditionContext(n2))
                 @test ctx1.values == (x=1, y=2)
-                # Check that the two ConditionContexts are collapsed 
+                # Check that the two ConditionContexts are collapsed
                 @test childcontext(ctx1) isa DefaultContext
                 # Then test the nesting the other way round
                 ctx2 = ConditionContext(n2, ConditionContext(n1))
diff --git a/test/independence.jl b/test/independence.jl
deleted file mode 100644
index a4a834a61..000000000
--- a/test/independence.jl
+++ /dev/null
@@ -1,11 +0,0 @@
-@testset "Turing independence" begin
-    @model coinflip(y) = begin
-        p ~ Beta(1, 1)
-        N = length(y)
-        for i in 1:N
-            y[i] ~ Bernoulli(p)
-        end
-    end
-    model = coinflip([1, 1, 0])
-    model(SampleFromPrior(), LikelihoodContext())
-end
diff --git a/test/linking.jl b/test/linking.jl
index d424a9c2d..4f1707263 100644
--- a/test/linking.jl
+++ b/test/linking.jl
@@ -85,7 +85,7 @@ end
                 DynamicPPL.link(vi, model)
             end
             # Difference should just be the log-absdet-jacobian "correction".
-            @test DynamicPPL.getlogp(vi) - DynamicPPL.getlogp(vi_linked) ≈ log(2)
+            @test DynamicPPL.getlogjoint(vi) - DynamicPPL.getlogjoint(vi_linked) ≈ log(2)
             @test vi_linked[@varname(m), dist] == LowerTriangular(vi[@varname(m), dist])
             # Linked one should be working with a lower-dimensional representation.
             @test length(vi_linked[:]) < length(vi[:])
@@ -98,7 +98,7 @@ end
             end
             @test length(vi_invlinked[:]) == length(vi[:])
             @test vi_invlinked[@varname(m), dist] ≈ LowerTriangular(vi[@varname(m), dist])
-            @test DynamicPPL.getlogp(vi_invlinked) ≈ DynamicPPL.getlogp(vi)
+            @test DynamicPPL.getlogjoint(vi_invlinked) ≈ DynamicPPL.getlogjoint(vi)
         end
     end
 
@@ -130,7 +130,7 @@ end
                     end
                     @test length(vi_linked[:]) == d * (d - 1) ÷ 2
                     # Should now include the log-absdet-jacobian correction.
-                    @test !(getlogp(vi_linked) ≈ lp)
+                    @test !(getlogjoint(vi_linked) ≈ lp)
                     # Invlinked.
                     vi_invlinked = if mutable
                         DynamicPPL.invlink!!(deepcopy(vi_linked), model)
@@ -138,7 +138,7 @@ end
                         DynamicPPL.invlink(vi_linked, model)
                     end
                     @test length(vi_invlinked[:]) == d^2
-                    @test getlogp(vi_invlinked) ≈ lp
+                    @test getlogjoint(vi_invlinked) ≈ lp
                 end
             end
         end
@@ -164,7 +164,7 @@ end
                 end
                 @test length(vi_linked[:]) == d - 1
                 # Should now include the log-absdet-jacobian correction.
-                @test !(getlogp(vi_linked) ≈ lp)
+                @test !(getlogjoint(vi_linked) ≈ lp)
                 # Invlinked.
                 vi_invlinked = if mutable
                     DynamicPPL.invlink!!(deepcopy(vi_linked), model)
@@ -172,7 +172,7 @@ end
                     DynamicPPL.invlink(vi_linked, model)
                 end
                 @test length(vi_invlinked[:]) == d
-                @test getlogp(vi_invlinked) ≈ lp
+                @test getlogjoint(vi_invlinked) ≈ lp
             end
         end
     end
diff --git a/test/model.jl b/test/model.jl
index dd5a35fe6..6e4a24ae6 100644
--- a/test/model.jl
+++ b/test/model.jl
@@ -41,7 +41,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
         m = vi[@varname(m)]
 
         # extract log pdf of variable object
-        lp = getlogp(vi)
+        lp = getlogjoint(vi)
 
         # log prior probability
         lprior = logprior(model, vi)
@@ -494,7 +494,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
             varinfo_linked_result = last(
                 DynamicPPL.evaluate!!(model, deepcopy(varinfo_linked), DefaultContext())
             )
-            @test getlogp(varinfo_linked) ≈ getlogp(varinfo_linked_result)
+            @test getlogjoint(varinfo_linked) ≈ getlogjoint(varinfo_linked_result)
         end
     end
 
diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl
index 61c842638..cfb222b66 100644
--- a/test/pointwise_logdensities.jl
+++ b/test/pointwise_logdensities.jl
@@ -1,6 +1,4 @@
 @testset "logdensities_likelihoods.jl" begin
-    mod_ctx = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.2)
-    mod_ctx2 = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.4, mod_ctx)
     @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
         example_values = DynamicPPL.TestUtils.rand_prior_true(model)
 
@@ -37,11 +35,6 @@
         lps = pointwise_logdensities(model, vi)
         logp = sum(sum, values(lps))
         @test logp ≈ (logprior_true + loglikelihood_true)
-
-        # Test that modifications of Setup are picked up
-        lps = pointwise_logdensities(model, vi, mod_ctx2)
-        logp = sum(sum, values(lps))
-        @test logp ≈ (logprior_true + loglikelihood_true) * 1.2 * 1.4
     end
 end
 
diff --git a/test/runtests.jl b/test/runtests.jl
index 72f33f2d0..4a9acf4e1 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -49,13 +49,13 @@ include("test_util.jl")
             include("Aqua.jl")
         end
         include("utils.jl")
+        include("accumulators.jl")
         include("compiler.jl")
         include("varnamedvector.jl")
         include("varinfo.jl")
         include("simple_varinfo.jl")
         include("model.jl")
         include("sampler.jl")
-        include("independence.jl")
         include("distribution_wrappers.jl")
         include("logdensityfunction.jl")
         include("linking.jl")
diff --git a/test/sampler.jl b/test/sampler.jl
index 8c4f1ed96..fe9fd331a 100644
--- a/test/sampler.jl
+++ b/test/sampler.jl
@@ -84,7 +84,7 @@
             let inits = (; p=0.2)
                 chain = sample(model, sampler, 1; initial_params=inits, progress=false)
                 @test chain[1].metadata.p.vals == [0.2]
-                @test getlogp(chain[1]) == lptrue
+                @test getlogjoint(chain[1]) == lptrue
 
                 # parallel sampling
                 chains = sample(
@@ -98,7 +98,7 @@
                 )
                 for c in chains
                     @test c[1].metadata.p.vals == [0.2]
-                    @test getlogp(c[1]) == lptrue
+                    @test getlogjoint(c[1]) == lptrue
                 end
             end
 
@@ -113,7 +113,7 @@
                 chain = sample(model, sampler, 1; initial_params=inits, progress=false)
                 @test chain[1].metadata.s.vals == [4]
                 @test chain[1].metadata.m.vals == [-1]
-                @test getlogp(chain[1]) == lptrue
+                @test getlogjoint(chain[1]) == lptrue
 
                 # parallel sampling
                 chains = sample(
@@ -128,7 +128,7 @@
                 for c in chains
                     @test c[1].metadata.s.vals == [4]
                     @test c[1].metadata.m.vals == [-1]
-                    @test getlogp(c[1]) == lptrue
+                    @test getlogjoint(c[1]) == lptrue
                 end
             end
 
diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl
index 380c24e7d..6f2f39a64 100644
--- a/test/simple_varinfo.jl
+++ b/test/simple_varinfo.jl
@@ -2,12 +2,12 @@
     @testset "constructor & indexing" begin
         @testset "NamedTuple" begin
             svi = SimpleVarInfo(; m=1.0)
-            @test getlogp(svi) == 0.0
+            @test getlogjoint(svi) == 0.0
             @test haskey(svi, @varname(m))
             @test !haskey(svi, @varname(m[1]))
 
             svi = SimpleVarInfo(; m=[1.0])
-            @test getlogp(svi) == 0.0
+            @test getlogjoint(svi) == 0.0
             @test haskey(svi, @varname(m))
             @test haskey(svi, @varname(m[1]))
             @test !haskey(svi, @varname(m[2]))
@@ -21,20 +21,21 @@
             @test !haskey(svi, @varname(m.a.b))
 
             svi = SimpleVarInfo{Float32}(; m=1.0)
-            @test getlogp(svi) isa Float32
+            @test getlogjoint(svi) isa Float32
 
-            svi = SimpleVarInfo((m=1.0,), 1.0)
-            @test getlogp(svi) == 1.0
+            svi = SimpleVarInfo((m=1.0,))
+            svi = accloglikelihood!!(svi, 1.0)
+            @test getlogjoint(svi) == 1.0
         end
 
         @testset "Dict" begin
             svi = SimpleVarInfo(Dict(@varname(m) => 1.0))
-            @test getlogp(svi) == 0.0
+            @test getlogjoint(svi) == 0.0
             @test haskey(svi, @varname(m))
             @test !haskey(svi, @varname(m[1]))
 
             svi = SimpleVarInfo(Dict(@varname(m) => [1.0]))
-            @test getlogp(svi) == 0.0
+            @test getlogjoint(svi) == 0.0
             @test haskey(svi, @varname(m))
             @test haskey(svi, @varname(m[1]))
             @test !haskey(svi, @varname(m[2]))
@@ -59,12 +60,12 @@
 
         @testset "VarNamedVector" begin
             svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m) => 1.0))
-            @test getlogp(svi) == 0.0
+            @test getlogjoint(svi) == 0.0
             @test haskey(svi, @varname(m))
             @test !haskey(svi, @varname(m[1]))
 
             svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m) => [1.0]))
-            @test getlogp(svi) == 0.0
+            @test getlogjoint(svi) == 0.0
             @test haskey(svi, @varname(m))
             @test haskey(svi, @varname(m[1]))
             @test !haskey(svi, @varname(m[2]))
@@ -98,11 +99,10 @@
                 vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn)
             end
             vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext()))
-            lp_orig = getlogp(vi)
 
             # `link!!`
             vi_linked = link!!(deepcopy(vi), model)
-            lp_linked = getlogp(vi_linked)
+            lp_linked = getlogjoint(vi_linked)
             values_unconstrained, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(
                 model, values_constrained...
             )
@@ -113,7 +113,7 @@
 
             # `invlink!!`
             vi_invlinked = invlink!!(deepcopy(vi_linked), model)
-            lp_invlinked = getlogp(vi_invlinked)
+            lp_invlinked = getlogjoint(vi_invlinked)
             lp_invlinked_true = DynamicPPL.TestUtils.logjoint_true(
                 model, values_constrained...
             )
@@ -152,7 +152,7 @@
             # DynamicPPL.settrans!!(deepcopy(svi_dict), true),
             # DynamicPPL.settrans!!(deepcopy(svi_vnv), true),
         )
-            # RandOM seed is set in each `@testset`, so we need to sample
+            # Random seed is set in each `@testset`, so we need to sample
             # a new realization for `m` here.
             retval = model()
 
@@ -166,7 +166,7 @@
             end
 
             # Logjoint should be non-zero wp. 1.
-            @test getlogp(svi_new) != 0
+            @test getlogjoint(svi_new) != 0
 
             ### Evaluation ###
             values_eval_constrained = DynamicPPL.TestUtils.rand_prior_true(model)
@@ -201,7 +201,7 @@
                 svi_eval = DynamicPPL.setindex!!(svi_eval, get(values_eval, vn), vn)
             end
 
-            # Reset the logp field.
+            # Reset the logp accumulators.
             svi_eval = DynamicPPL.resetlogp!!(svi_eval)
 
             # Compute `logjoint` using the varinfo.
@@ -250,7 +250,7 @@
                 end
 
-                # `getlogp` should be equal to the logjoint with log-absdet-jac correction.
+                # `getlogjoint` should be equal to the logjoint with log-absdet-jac correction.
-                lp = getlogp(svi)
+                lp = getlogjoint(svi)
                 # needs higher atol because of https://github.com/TuringLang/Bijectors.jl/issues/375
                 @test lp ≈ lp_true atol = 1.2e-5
             end
@@ -306,7 +306,7 @@
                 DynamicPPL.tovec(retval_unconstrained.m)
 
             # The resulting varinfo should hold the correct logp.
-            lp = getlogp(vi_linked_result)
+            lp = getlogjoint(vi_linked_result)
             @test lp ≈ lp_true
         end
     end
diff --git a/test/submodels.jl b/test/submodels.jl
index e79eed2c3..d3a2f17e7 100644
--- a/test/submodels.jl
+++ b/test/submodels.jl
@@ -35,7 +35,7 @@ using Test
                 @test model()[1] == x_val
                 # Test that the logp was correctly set
                 vi = VarInfo(model)
-                @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(a.y)])
+                @test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(a.y)])
                 # Check the keys
                 @test Set(keys(VarInfo(model))) == Set([@varname(a.y)])
             end
@@ -67,7 +67,7 @@ using Test
                 @test model()[1] == x_val
                 # Test that the logp was correctly set
                 vi = VarInfo(model)
-                @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(y)])
+                @test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(y)])
                 # Check the keys
                 @test Set(keys(VarInfo(model))) == Set([@varname(y)])
             end
@@ -99,7 +99,7 @@ using Test
                 @test model()[1] == x_val
                 # Test that the logp was correctly set
                 vi = VarInfo(model)
-                @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(b.y)])
+                @test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(b.y)])
                 # Check the keys
                 @test Set(keys(VarInfo(model))) == Set([@varname(b.y)])
             end
@@ -148,7 +148,7 @@ using Test
             # No conditioning
             vi = VarInfo(h())
             @test Set(keys(vi)) == Set([@varname(a.b.x), @varname(a.b.y)])
-            @test getlogp(vi) ==
+            @test getlogjoint(vi) ==
                 logpdf(Normal(), vi[@varname(a.b.x)]) +
                   logpdf(Normal(), vi[@varname(a.b.y)])
 
@@ -174,7 +174,7 @@ using Test
             @testset "$name" for (name, model) in models
                 vi = VarInfo(model)
                 @test Set(keys(vi)) == Set([@varname(a.b.y)])
-                @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(a.b.y)])
+                @test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(a.b.y)])
             end
         end
     end
diff --git a/test/threadsafe.jl b/test/threadsafe.jl
index 72c439db8..5b4f6951f 100644
--- a/test/threadsafe.jl
+++ b/test/threadsafe.jl
@@ -4,9 +4,12 @@
         threadsafe_vi = @inferred DynamicPPL.ThreadSafeVarInfo(vi)
 
         @test threadsafe_vi.varinfo === vi
-        @test threadsafe_vi.logps isa Vector{typeof(Ref(getlogp(vi)))}
-        @test length(threadsafe_vi.logps) == Threads.nthreads()
-        @test all(iszero(x[]) for x in threadsafe_vi.logps)
+        @test threadsafe_vi.accs_by_thread isa Vector{<:DynamicPPL.AccumulatorTuple}
+        @test length(threadsafe_vi.accs_by_thread) == Threads.nthreads()
+        expected_accs = DynamicPPL.AccumulatorTuple(
+            (DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(vi))...
+        )
+        @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread)
     end
 
     # TODO: Add more tests of the public API
@@ -14,23 +17,27 @@
         vi = VarInfo(gdemo_default)
         threadsafe_vi = DynamicPPL.ThreadSafeVarInfo(vi)
 
-        lp = getlogp(vi)
-        @test getlogp(threadsafe_vi) == lp
+        lp = getlogjoint(vi)
+        @test getlogjoint(threadsafe_vi) == lp
 
-        acclogp!!(threadsafe_vi, 42)
-        @test threadsafe_vi.logps[Threads.threadid()][] == 42
-        @test getlogp(vi) == lp
-        @test getlogp(threadsafe_vi) == lp + 42
+        threadsafe_vi = DynamicPPL.acclogprior!!(threadsafe_vi, 42)
+        @test threadsafe_vi.accs_by_thread[Threads.threadid()][:LogPrior].logp == 42
+        @test getlogjoint(vi) == lp
+        @test getlogjoint(threadsafe_vi) == lp + 42
 
-        resetlogp!!(threadsafe_vi)
-        @test iszero(getlogp(vi))
-        @test iszero(getlogp(threadsafe_vi))
-        @test all(iszero(x[]) for x in threadsafe_vi.logps)
+        threadsafe_vi = resetlogp!!(threadsafe_vi)
+        @test iszero(getlogjoint(threadsafe_vi))
+        expected_accs = DynamicPPL.AccumulatorTuple(
+            (DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(threadsafe_vi.varinfo))...
+        )
+        @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread)
 
-        setlogp!!(threadsafe_vi, 42)
-        @test getlogp(vi) == 42
-        @test getlogp(threadsafe_vi) == 42
-        @test all(iszero(x[]) for x in threadsafe_vi.logps)
+        threadsafe_vi = setlogprior!!(threadsafe_vi, 42)
+        @test getlogjoint(threadsafe_vi) == 42
+        expected_accs = DynamicPPL.AccumulatorTuple(
+            (DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(threadsafe_vi.varinfo))...
+        )
+        @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread)
     end
 
     @testset "model" begin
@@ -48,7 +55,7 @@
 
         vi = VarInfo()
         wthreads(x)(vi)
-        lp_w_threads = getlogp(vi)
+        lp_w_threads = getlogjoint(vi)
         if Threads.nthreads() == 1
             @test vi_ isa VarInfo
         else
@@ -65,7 +72,7 @@
             vi,
             SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()),
         )
-        @test getlogp(vi) ≈ lp_w_threads
+        @test getlogjoint(vi) ≈ lp_w_threads
         @test vi_ isa DynamicPPL.ThreadSafeVarInfo
 
         println("  evaluate_threadsafe!!:")
@@ -85,7 +92,7 @@
 
         vi = VarInfo()
         wothreads(x)(vi)
-        lp_wo_threads = getlogp(vi)
+        lp_wo_threads = getlogjoint(vi)
         if Threads.nthreads() == 1
             @test vi_ isa VarInfo
         else
@@ -104,7 +111,7 @@
             vi,
             SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()),
         )
-        @test getlogp(vi) ≈ lp_w_threads
+        @test getlogjoint(vi) ≈ lp_w_threads
         @test vi_ isa VarInfo
 
         println("  evaluate_threadunsafe!!:")
diff --git a/test/utils.jl b/test/utils.jl
index d683f132d..e4bac14e0 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -1,15 +1,34 @@
 @testset "utils.jl" begin
     @testset "addlogprob!" begin
         @model function testmodel()
-            global lp_before = getlogp(__varinfo__)
+            global lp_before = getlogjoint(__varinfo__)
             @addlogprob!(42)
-            return global lp_after = getlogp(__varinfo__)
+            return global lp_after = getlogjoint(__varinfo__)
         end
 
-        model = testmodel()
-        varinfo = VarInfo(model)
+        varinfo = VarInfo(testmodel())
         @test iszero(lp_before)
-        @test getlogp(varinfo) == lp_after == 42
+        @test getlogjoint(varinfo) == lp_after == 42
+        @test getloglikelihood(varinfo) == 42
+
+        @model function testmodel_nt()
+            global lp_before = getlogjoint(__varinfo__)
+            @addlogprob! (; logprior=(pi + 1), loglikelihood=42)
+            return global lp_after = getlogjoint(__varinfo__)
+        end
+
+        varinfo = VarInfo(testmodel_nt())
+        @test iszero(lp_before)
+        @test getlogjoint(varinfo) == lp_after == 42 + 1 + pi
+        @test getloglikelihood(varinfo) == 42
+        @test getlogprior(varinfo) == pi + 1
+
+        @model function testmodel_nt2()
+            global lp_before = getlogjoint(__varinfo__)
+            llh_nt = (; loglikelihood=42)
+            @addlogprob! llh_nt
+            return global lp_after = getlogjoint(__varinfo__)
+        end
+
+        # Same checks as above, but with the NamedTuple passed via a local variable.
+        varinfo = VarInfo(testmodel_nt2())
+        @test iszero(lp_before)
+        @test getlogjoint(varinfo) == lp_after == 42
+        @test getloglikelihood(varinfo) == 42
     end
 
     @testset "getargs_dottilde" begin
diff --git a/test/varinfo.jl b/test/varinfo.jl
index 777917aa6..1c597f951 100644
--- a/test/varinfo.jl
+++ b/test/varinfo.jl
@@ -80,7 +80,7 @@ end
 
         function test_base!!(vi_original)
             vi = empty!!(vi_original)
-            @test getlogp(vi) == 0
+            @test getlogjoint(vi) == 0
             @test isempty(vi[:])
 
             vn = @varname x
@@ -123,13 +123,25 @@ end
 
     @testset "get/set/acc/resetlogp" begin
         function test_varinfo_logp!(vi)
-            @test DynamicPPL.getlogp(vi) === 0.0
-            vi = DynamicPPL.setlogp!!(vi, 1.0)
-            @test DynamicPPL.getlogp(vi) === 1.0
-            vi = DynamicPPL.acclogp!!(vi, 1.0)
-            @test DynamicPPL.getlogp(vi) === 2.0
+            @test DynamicPPL.getlogjoint(vi) === 0.0
+            vi = DynamicPPL.setlogprior!!(vi, 1.0)
+            @test DynamicPPL.getlogprior(vi) === 1.0
+            @test DynamicPPL.getloglikelihood(vi) === 0.0
+            @test DynamicPPL.getlogjoint(vi) === 1.0
+            vi = DynamicPPL.acclogprior!!(vi, 1.0)
+            @test DynamicPPL.getlogprior(vi) === 2.0
+            @test DynamicPPL.getloglikelihood(vi) === 0.0
+            @test DynamicPPL.getlogjoint(vi) === 2.0
+            vi = DynamicPPL.setloglikelihood!!(vi, 1.0)
+            @test DynamicPPL.getlogprior(vi) === 2.0
+            @test DynamicPPL.getloglikelihood(vi) === 1.0
+            @test DynamicPPL.getlogjoint(vi) === 3.0
+            vi = DynamicPPL.accloglikelihood!!(vi, 1.0)
+            @test DynamicPPL.getlogprior(vi) === 2.0
+            @test DynamicPPL.getloglikelihood(vi) === 2.0
+            @test DynamicPPL.getlogjoint(vi) === 4.0
             vi = DynamicPPL.resetlogp!!(vi)
-            @test DynamicPPL.getlogp(vi) === 0.0
+            @test DynamicPPL.getlogjoint(vi) === 0.0
         end
 
         vi = VarInfo()
@@ -140,6 +152,98 @@ end
         test_varinfo_logp!(SimpleVarInfo(DynamicPPL.VarNamedVector()))
     end
 
+    @testset "accumulators" begin
+        @model function demo()
+            a ~ Normal()
+            b ~ Normal()
+            c ~ Normal()
+            d ~ Normal()
+            return nothing
+        end
+
+        values = (; a=1.0, b=2.0, c=3.0, d=4.0)
+        lp_a = logpdf(Normal(), values.a)
+        lp_b = logpdf(Normal(), values.b)
+        lp_c = logpdf(Normal(), values.c)
+        lp_d = logpdf(Normal(), values.d)
+        m = demo() | (; c=values.c, d=values.d)
+
+        vi = DynamicPPL.reset_num_produce!!(
+            DynamicPPL.unflatten(VarInfo(m), collect(values))
+        )
+
+        vi = last(DynamicPPL.evaluate!!(m, deepcopy(vi)))
+        @test getlogprior(vi) == lp_a + lp_b
+        @test getloglikelihood(vi) == lp_c + lp_d
+        @test getlogp(vi) == (; logprior=lp_a + lp_b, loglikelihood=lp_c + lp_d)
+        @test getlogjoint(vi) == lp_a + lp_b + lp_c + lp_d
+        @test get_num_produce(vi) == 2
+        @test begin
+            vi = acclogprior!!(vi, 1.0)
+            getlogprior(vi) == lp_a + lp_b + 1.0
+        end
+        @test begin
+            vi = accloglikelihood!!(vi, 1.0)
+            getloglikelihood(vi) == lp_c + lp_d + 1.0
+        end
+        @test begin
+            vi = setlogprior!!(vi, -1.0)
+            getlogprior(vi) == -1.0
+        end
+        @test begin
+            vi = setloglikelihood!!(vi, -1.0)
+            getloglikelihood(vi) == -1.0
+        end
+        @test begin
+            vi = setlogp!!(vi, (logprior=-3.0, loglikelihood=-3.0))
+            getlogp(vi) == (; logprior=-3.0, loglikelihood=-3.0)
+        end
+        @test begin
+            vi = acclogp!!(vi, (logprior=1.0, loglikelihood=1.0))
+            getlogp(vi) == (; logprior=-2.0, loglikelihood=-2.0)
+        end
+        @test getlogp(setlogp!!(vi, getlogp(vi))) == getlogp(vi)
+
+        vi = last(
+            DynamicPPL.evaluate!!(
+                m, DynamicPPL.setaccs!!(deepcopy(vi), (LogPriorAccumulator(),))
+            ),
+        )
+        @test getlogprior(vi) == lp_a + lp_b
+        @test_throws "has no field LogLikelihood" getloglikelihood(vi)
+        @test_throws "has no field LogLikelihood" getlogp(vi)
+        @test_throws "has no field LogLikelihood" getlogjoint(vi)
+        @test_throws "has no field NumProduce" get_num_produce(vi)
+        @test begin
+            vi = acclogprior!!(vi, 1.0)
+            getlogprior(vi) == lp_a + lp_b + 1.0
+        end
+        @test begin
+            vi = setlogprior!!(vi, -1.0)
+            getlogprior(vi) == -1.0
+        end
+
+        vi = last(
+            DynamicPPL.evaluate!!(
+                m, DynamicPPL.setaccs!!(deepcopy(vi), (NumProduceAccumulator(),))
+            ),
+        )
+        @test_throws "has no field LogPrior" getlogprior(vi)
+        @test_throws "has no field LogLikelihood" getloglikelihood(vi)
+        @test_throws "has no field LogPrior" getlogp(vi)
+        @test_throws "has no field LogPrior" getlogjoint(vi)
+        @test get_num_produce(vi) == 2
+
+        # Test evaluating without any accumulators.
+        vi = last(DynamicPPL.evaluate!!(m, DynamicPPL.setaccs!!(deepcopy(vi), ())))
+        @test_throws "has no field LogPrior" getlogprior(vi)
+        @test_throws "has no field LogLikelihood" getloglikelihood(vi)
+        @test_throws "has no field LogPrior" getlogp(vi)
+        @test_throws "has no field LogPrior" getlogjoint(vi)
+        @test_throws "has no field NumProduce" get_num_produce(vi)
+        @test_throws "has no field NumProduce" reset_num_produce!!(vi)
+    end
+
     @testset "flags" begin
         # Test flag setting:
         #    is_flagged, set_flag!, unset_flag!
@@ -455,12 +559,24 @@ end
 
         ## `untyped_varinfo`
         vi = DynamicPPL.untyped_varinfo(model)
         vi = DynamicPPL.settrans!!(vi, true, vn)
         # Sample in unconstrained space.
         vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext()))
         f = DynamicPPL.from_linked_internal_transform(vi, vn, dist)
         x = f(DynamicPPL.getindex_internal(vi, vn))
-        @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true)
+        @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true)
 
         ## `typed_varinfo`
         vi = DynamicPPL.typed_varinfo(model)
@@ -469,7 +585,7 @@ end
         vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext()))
         f = DynamicPPL.from_linked_internal_transform(vi, vn, dist)
         x = f(DynamicPPL.getindex_internal(vi, vn))
-        @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true)
+        @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true)
 
         ### `SimpleVarInfo`
         ## `SimpleVarInfo{<:NamedTuple}`
@@ -478,7 +594,7 @@ end
         vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext()))
         f = DynamicPPL.from_linked_internal_transform(vi, vn, dist)
         x = f(DynamicPPL.getindex_internal(vi, vn))
-        @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true)
+        @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true)
 
         ## `SimpleVarInfo{<:Dict}`
         vi = DynamicPPL.settrans!!(SimpleVarInfo(Dict()), true)
@@ -486,7 +602,7 @@ end
         vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext()))
         f = DynamicPPL.from_linked_internal_transform(vi, vn, dist)
         x = f(DynamicPPL.getindex_internal(vi, vn))
-        @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true)
+        @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true)
 
         ## `SimpleVarInfo{<:VarNamedVector}`
         vi = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true)
@@ -494,7 +610,7 @@ end
         vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext()))
         f = DynamicPPL.from_linked_internal_transform(vi, vn, dist)
         x = f(DynamicPPL.getindex_internal(vi, vn))
-        @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true)
+        @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true)
     end
 
     @testset "values_as" begin
@@ -596,8 +712,8 @@ end
 
                     lp = logjoint(model, varinfo)
                     @test lp ≈ lp_true
-                    @test getlogp(varinfo) ≈ lp_true
-                    lp_linked = getlogp(varinfo_linked)
+                    @test getlogjoint(varinfo) ≈ lp_true
+                    lp_linked = getlogjoint(varinfo_linked)
                     @test lp_linked ≈ lp_linked_true
 
                     # TODO: Compare values once we are no longer working with `NamedTuple` for
@@ -609,13 +725,36 @@ end
                             varinfo_linked_unflattened, model
                         )
                         @test length(varinfo_invlinked[:]) == length(varinfo[:])
-                        @test getlogp(varinfo_invlinked) ≈ lp_true
+                        @test getlogjoint(varinfo_invlinked) ≈ lp_true
                     end
                 end
             end
         end
     end
 
+    @testset "unflatten type stability" begin
+        @model function demo(y)
+            x ~ Normal()
+            y ~ Normal(x, 1)
+            return nothing
+        end
+
+        model = demo(0.0)
+        varinfos = DynamicPPL.TestUtils.setup_varinfos(
+            model, (; x=1.0), (@varname(x),); include_threadsafe=true
+        )
+        @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
+            # Skip the severely non-concrete `SimpleVarInfo` types, since checking for type
+            # stability for them doesn't make much sense anyway.
+            if varinfo isa SimpleVarInfo{OrderedDict{Any,Any}} ||
+                varinfo isa
+               DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo{OrderedDict{Any,Any}}}
+                continue
+            end
+            @inferred DynamicPPL.unflatten(varinfo, varinfo[:])
+        end
+    end
+
     @testset "subset" begin
         @model function demo_subsetting_varinfo(::Type{TV}=Vector{Float64}) where {TV}
             s ~ InverseGamma(2, 3)
@@ -941,19 +1080,19 @@ end
 
         # First iteration, variables are added to vi
         # variables sampled in order: z1, a1, b, z2, a2, z3
-        DynamicPPL.increment_num_produce!(vi)
+        vi = DynamicPPL.increment_num_produce!!(vi)
         randr(vi, vn_z1, dists[1])
         randr(vi, vn_a1, dists[2])
-        DynamicPPL.increment_num_produce!(vi)
+        vi = DynamicPPL.increment_num_produce!!(vi)
         randr(vi, vn_b, dists[2])
         randr(vi, vn_z2, dists[1])
         randr(vi, vn_a2, dists[2])
-        DynamicPPL.increment_num_produce!(vi)
+        vi = DynamicPPL.increment_num_produce!!(vi)
         randr(vi, vn_z3, dists[1])
         @test vi.metadata.orders == [1, 1, 2, 2, 2, 3]
         @test DynamicPPL.get_num_produce(vi) == 3
 
-        DynamicPPL.reset_num_produce!(vi)
+        vi = DynamicPPL.reset_num_produce!!(vi)
         DynamicPPL.set_retained_vns_del!(vi)
         @test DynamicPPL.is_flagged(vi, vn_z1, "del")
         @test DynamicPPL.is_flagged(vi, vn_a1, "del")
@@ -961,12 +1100,12 @@ end
         @test DynamicPPL.is_flagged(vi, vn_a2, "del")
         @test DynamicPPL.is_flagged(vi, vn_z3, "del")
 
-        DynamicPPL.increment_num_produce!(vi)
+        vi = DynamicPPL.increment_num_produce!!(vi)
         randr(vi, vn_z1, dists[1])
         randr(vi, vn_a1, dists[2])
-        DynamicPPL.increment_num_produce!(vi)
+        vi = DynamicPPL.increment_num_produce!!(vi)
         randr(vi, vn_z2, dists[1])
-        DynamicPPL.increment_num_produce!(vi)
+        vi = DynamicPPL.increment_num_produce!!(vi)
         randr(vi, vn_z3, dists[1])
         randr(vi, vn_a2, dists[2])
         @test vi.metadata.orders == [1, 1, 2, 2, 3, 3]
@@ -975,21 +1114,21 @@ end
         vi = empty!!(DynamicPPL.typed_varinfo(vi))
         # First iteration, variables are added to vi
         # variables sampled in order: z1, a1, b, z2, a2, z3
-        DynamicPPL.increment_num_produce!(vi)
+        vi = DynamicPPL.increment_num_produce!!(vi)
         randr(vi, vn_z1, dists[1])
         randr(vi, vn_a1, dists[2])
-        DynamicPPL.increment_num_produce!(vi)
+        vi = DynamicPPL.increment_num_produce!!(vi)
         randr(vi, vn_b, dists[2])
         randr(vi, vn_z2, dists[1])
         randr(vi, vn_a2, dists[2])
-        DynamicPPL.increment_num_produce!(vi)
+        vi = DynamicPPL.increment_num_produce!!(vi)
         randr(vi, vn_z3, dists[1])
         @test vi.metadata.z.orders == [1, 2, 3]
         @test vi.metadata.a.orders == [1, 2]
         @test vi.metadata.b.orders == [2]
         @test DynamicPPL.get_num_produce(vi) == 3
 
-        DynamicPPL.reset_num_produce!(vi)
+        vi = DynamicPPL.reset_num_produce!!(vi)
         DynamicPPL.set_retained_vns_del!(vi)
         @test DynamicPPL.is_flagged(vi, vn_z1, "del")
         @test DynamicPPL.is_flagged(vi, vn_a1, "del")
@@ -997,12 +1136,12 @@ end
         @test DynamicPPL.is_flagged(vi, vn_a2, "del")
         @test DynamicPPL.is_flagged(vi, vn_z3, "del")
 
-        DynamicPPL.increment_num_produce!(vi)
+        vi = DynamicPPL.increment_num_produce!!(vi)
         randr(vi, vn_z1, dists[1])
         randr(vi, vn_a1, dists[2])
-        DynamicPPL.increment_num_produce!(vi)
+        vi = DynamicPPL.increment_num_produce!!(vi)
         randr(vi, vn_z2, dists[1])
-        DynamicPPL.increment_num_produce!(vi)
+        vi = DynamicPPL.increment_num_produce!!(vi)
         randr(vi, vn_z3, dists[1])
         randr(vi, vn_a2, dists[2])
         @test vi.metadata.z.orders == [1, 2, 3]
@@ -1017,8 +1156,8 @@ end
 
         n = length(varinfo[:])
         # `Bool`.
-        @test getlogp(DynamicPPL.unflatten(varinfo, fill(true, n))) isa typeof(float(1))
+        @test getlogjoint(DynamicPPL.unflatten(varinfo, fill(true, n))) isa typeof(float(1))
         # `Int`.
-        @test getlogp(DynamicPPL.unflatten(varinfo, fill(1, n))) isa typeof(float(1))
+        @test getlogjoint(DynamicPPL.unflatten(varinfo, fill(1, n))) isa typeof(float(1))
     end
 end
diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl
index bd3f5553f..f21d458a8 100644
--- a/test/varnamedvector.jl
+++ b/test/varnamedvector.jl
@@ -607,7 +607,7 @@ end
                 DynamicPPL.evaluate!!(model, deepcopy(varinfo), DefaultContext())
             )
             # Log density should be the same.
-            @test getlogp(varinfo_eval) ≈ logp_true
+            @test getlogjoint(varinfo_eval) ≈ logp_true
             # Values should be the same.
             DynamicPPL.TestUtils.test_values(varinfo_eval, value_true, vns)
 
@@ -616,7 +616,7 @@ end
                 DynamicPPL.evaluate!!(model, deepcopy(varinfo), SamplingContext())
             )
             # Log density should be different.
-            @test getlogp(varinfo_sample) != getlogp(varinfo)
+            @test getlogjoint(varinfo_sample) != getlogjoint(varinfo)
             # Values should be different.
             DynamicPPL.TestUtils.test_values(
                 varinfo_sample, value_true, vns; compare=!isequal