
Conversation

@penelopeysm (Member) commented Nov 7, 2025

Much like how #1113 makes LogDensityFunction much faster by simply removing the use of Metadata, this PR makes InitContext much faster by also cutting out Metadata.

I mainly wanted to put this out there so that we have something to discuss on Monday's meeting.

Benchmarks

Here are some benchmarks on the trivial model. Obviously, this tiny model makes this PR look really good. But given that the speedups for NT and Dict are on the same scale as for Vector, I assume that it will generalise to other models in exactly the same way as discussed previously.

using DynamicPPL, Distributions, Chairmarks, LogDensityProblems
@model f() = x ~ Normal()
model = f()
vi = VarInfo(model)
_, vi = DynamicPPL.init!!(model, vi, InitFromParams((; x = 0.0)))
oavi = DynamicPPL.Experimental.OnlyAccsVarInfo(vi.accs)
ldf = DynamicPPL.LogDensityFunction(model, getlogjoint)
fldf = DynamicPPL.Experimental.FastLDF(model, getlogjoint)
xvec = [5.0]
xnt = InitFromParams((x = 5.0,))
xdict = InitFromParams(Dict(@varname(x) => 5.0))

##### Vector #####

# current (slow) LDF.
@be LogDensityProblems.logdensity($ldf, $xvec)
# 157.736 ns (6 allocs: 192 bytes)

# FastLDF.
@be LogDensityProblems.logdensity($fldf, $xvec)
# 10.939 ns

##### NamedTuple #####

# current InitFromParams(NT) with a full VarInfo.
@be getlogjoint(last(DynamicPPL.init!!($model, $vi, $xnt)))
# 72.297 ns (4 allocs: 128 bytes)

# the equivalent of FastNamedTupleLDF.
@be getlogjoint(last(DynamicPPL.init!!($model, $oavi, $xnt)))
# 4.033 ns

# current LDF with SimpleVarInfo(NT). Note that this isn't really the same thing
# as FastNamedTupleLDF because it still uses a vector for evaluation.
svi = SimpleVarInfo((; x = 0.0))
ldf_svi = DynamicPPL.LogDensityFunction(model, getlogjoint, svi)
@be LogDensityProblems.logdensity($ldf_svi, $xvec)
# 4.345 ns

##### Dict #####

# current InitFromParams(Dict) with a full VarInfo.
@be getlogjoint(last(DynamicPPL.init!!($model, $vi, $xdict)))
# 123.065 ns (4 allocs: 128 bytes)

# the equivalent of FastDictLDF.
@be getlogjoint(last(DynamicPPL.init!!($model, $oavi, $xdict)))
# 13.253 ns

# current LDF with SimpleVarInfo(Dict).
dsvi = SimpleVarInfo(Dict(@varname(x) => 0.0))
ldf_dsvi = DynamicPPL.LogDensityFunction(model, getlogjoint, dsvi)
@be LogDensityProblems.logdensity($ldf_dsvi, $xvec)
# 117.060 ns (6 allocs: 384 bytes)

AD

On top of that, you can use this to run faster AD on NamedTuple inputs.

import DifferentiationInterface as DI
using Mooncake, Enzyme

function namedtuple_logjoint(xnt)
    return getlogjoint(last(DynamicPPL.init!!(model, oavi, InitFromParams(xnt))))
end

prep = DI.prepare_gradient(namedtuple_logjoint, AutoMooncake(), (; x = 1.0))
@be DI.gradient(namedtuple_logjoint, prep, AutoMooncake(), (; x = 1.0))
# (x = -1.0,)
# 3.088 μs (65 allocs: 1.922 KiB)

ad = AutoEnzyme(; mode=set_runtime_activity(Reverse), function_annotation=Const)
prep = DI.prepare_gradient(namedtuple_logjoint, ad, (; x = 1.0))
@be DI.gradient(namedtuple_logjoint, prep, ad, (; x = 1.0))
# (x = -1.0,)
# 4.535 μs (90 allocs: 3.391 KiB)

(This isn't new; you could always do this, it's just faster now. Using vi instead of oavi makes Mooncake about 1.6x slower.)

Note that this effectively provides a 'fast' implementation of NamedLogDensity (#880), although it should be mentioned that the vector-based LogDensityFunction is still a lot faster for both Mooncake and Enzyme.

(ForwardDiff and ReverseDiff only work with vector-valued inputs.)

Outlook

I think the main thing that I learnt from this is that fast NamedTuple and Dict evaluation is already quite easily obtainable (this PR is < 10 lines) by composing InitFromParams with OnlyAccsVarInfo.

This makes me think that what's now called FastEvalVectorContext could actually just become another flavour of InitFromParams. That would, IMO, lead to an incredibly satisfying reduction in code complexity, and conceptually I find it extremely clear:

  1. InitFromParams gives you SOME kind of way to obtain params. It might be from a NT, or a Dict, or a vector where you know which sub-range belongs to which param.
  2. OnlyAccsVarInfo only contains accs, and you can use that for 'fast evaluation' as long as the accs are the only thing you need.

So the combination of InitFromParams + OnlyAccsVarInfo gives you a clear-cut way of getting accs given params.
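
To make that composition concrete, here is a minimal sketch that just packages the calls from the benchmark script above into a single helper. The name fast_logjoint is hypothetical (not something this PR defines), and the OnlyAccsVarInfo is constructed inside the function only for brevity; in a hot loop you would build it once and reuse it.

using DynamicPPL

# Hypothetical helper: evaluate the log joint of `model` at `params`
# (a NamedTuple or a Dict keyed by VarName) without ever touching Metadata.
function fast_logjoint(model, params)
    # Accumulator-only VarInfo: keeps the accs of a full VarInfo, drops the rest.
    oavi = DynamicPPL.Experimental.OnlyAccsVarInfo(VarInfo(model).accs)
    # InitFromParams supplies the values; init!! only has to fill the accumulators.
    _, oavi = DynamicPPL.init!!(model, oavi, InitFromParams(params))
    return getlogjoint(oavi)
end

fast_logjoint(model, (; x = 5.0))               # NamedTuple params
fast_logjoint(model, Dict(@varname(x) => 5.0))  # Dict params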

This also means that generalising 'fast evaluation' to other types of parameters is incredibly easy: just implement a new method for InitFromParams and voila, it will straightaway work with OnlyAccsVarInfo.
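
For example (purely illustrative: the strategy type InitFromFunction and the DynamicPPL.init method signature below are my assumptions about the internal interface, so check how InitFromPrior is implemented for the real one), a new parameter source could look roughly like this, and would then compose with OnlyAccsVarInfo in exactly the same way.

import Random
using DynamicPPL

# Hypothetical strategy that obtains each value from a user-supplied callable
# mapping a VarName to a value. The type name and supertype are assumptions.
struct InitFromFunction{F} <: DynamicPPL.AbstractInitStrategy
    f::F
end

# Assumed extension point; the actual method name/signature may differ.
function DynamicPPL.init(
    rng::Random.AbstractRNG, vn::VarName, dist, strategy::InitFromFunction
)
    # Ignore rng and dist: the value comes entirely from the callable.
    return strategy.f(vn)
end

# If the interface matches, this works with OnlyAccsVarInfo just like InitFromParams:
# DynamicPPL.init!!(model, oavi, InitFromFunction(vn -> 5.0))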

The only thing I'm not in love with is that init!! starts to look a bit like a misnomer. Sure, I guess, it's still initialising accumulators for you. But I think this reinforces the notion I've had for a while now that InitFromParams was always about far more than just filling a VarInfo with specific params: really, it's a way to completely decouple parameter values from VarInfo.

@penelopeysm penelopeysm changed the base branch from main to py/fastldf November 7, 2025 17:31
    params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior()
)
-   return InitFromParams(to_varname_dict(params), fallback)
+   return new{typeof(params),typeof(fallback)}(params, fallback)
@penelopeysm (Member, Author) commented Nov 7, 2025

This code, which is present on main, is actually (sadly) quite bad code (written by yours truly). For the trivial model, it made evaluation with a simple NamedTuple about 10x slower, because to_varname_dict returns a Dict{VarName,Any}, which is too loosely typed and leads to slow lookups.

DynamicPPL.jl/src/utils.jl, lines 851 to 855 in 08fffa2:

# Convert (x=1,) to Dict(@varname(x) => 1)
function to_varname_dict(nt::NamedTuple)
    return Dict{VarName,Any}(VarName{k}() => v for (k, v) in pairs(nt))
end
to_varname_dict(d::AbstractDict) = d
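
For intuition, here is a minimal, self-contained illustration of the cost (not taken from the PR): with an Any value type, every lookup returns an abstractly typed value, so any arithmetic on it has to go through dynamic dispatch, whereas a concretely typed container keeps everything inferable.

using DynamicPPL, Chairmarks

vn = @varname(x)
loose = Dict{VarName,Any}(vn => 1.0)      # what to_varname_dict produces
tight = Dict{VarName,Float64}(vn => 1.0)  # a concretely typed alternative

# The first lookup yields a value typed as Any, so the addition dispatches at
# runtime; the second is fully inferred.
@be $loose[$vn] + 1.0
@be $tight[$vn] + 1.0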

github-actions bot (Contributor) commented Nov 7, 2025

Benchmark Report for Commit 699aa23

Computer Information

Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

┌───────────────────────┬───────┬─────────────┬───────────────────┬────────┬────────────────┬─────────────────┐
│                 Model │   Dim │  AD Backend │           VarInfo │ Linked │ t(eval)/t(ref) │ t(grad)/t(eval) │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼────────────────┼─────────────────┤
│ Simple assume observe │     1 │ forwarddiff │             typed │  false │            6.1 │             1.8 │
│           Smorgasbord │   201 │ forwarddiff │             typed │  false │          720.7 │            45.2 │
│           Smorgasbord │   201 │ forwarddiff │ simple_namedtuple │   true │          412.9 │            58.1 │
│           Smorgasbord │   201 │ forwarddiff │           untyped │   true │          787.7 │            36.4 │
│           Smorgasbord │   201 │ forwarddiff │       simple_dict │   true │         7024.1 │            24.8 │
│           Smorgasbord │   201 │ forwarddiff │      typed_vector │   true │          750.2 │            42.2 │
│           Smorgasbord │   201 │ forwarddiff │    untyped_vector │   true │          780.8 │            37.2 │
│           Smorgasbord │   201 │ reversediff │             typed │   true │          910.8 │            44.7 │
│           Smorgasbord │   201 │    mooncake │             typed │   true │          727.8 │             5.9 │
│           Smorgasbord │   201 │      enzyme │             typed │   true │          890.4 │             4.5 │
│    Loop univariate 1k │  1000 │    mooncake │             typed │   true │         3873.5 │             5.9 │
│       Multivariate 1k │  1000 │    mooncake │             typed │   true │          996.0 │             9.1 │
│   Loop univariate 10k │ 10000 │    mooncake │             typed │   true │        42353.1 │             5.6 │
│      Multivariate 10k │ 10000 │    mooncake │             typed │   true │         8600.8 │            10.0 │
│               Dynamic │    10 │    mooncake │             typed │   true │          123.5 │            11.3 │
│              Submodel │     1 │    mooncake │             typed │   true │            9.0 │             6.3 │
│                   LDA │    12 │ reversediff │             typed │   true │          989.5 │             2.1 │
└───────────────────────┴───────┴─────────────┴───────────────────┴────────┴────────────────┴─────────────────┘

codecov bot commented Nov 7, 2025

Codecov Report

❌ Patch coverage is 12.50000% with 7 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.72%. Comparing base (4ec0c72) to head (699aa23).

Files with missing lines │ Patch % │ Lines
src/fasteval.jl          │ 0.00%   │ 7 Missing ⚠️
Additional details and impacted files
@@              Coverage Diff               @@
##           py/fastldf    #1125      +/-   ##
==============================================
- Coverage       81.86%   81.72%   -0.15%     
==============================================
  Files              41       41              
  Lines            3949     3956       +7     
==============================================
  Hits             3233     3233              
- Misses            716      723       +7     


github-actions bot (Contributor) commented Nov 8, 2025

DynamicPPL.jl documentation for PR #1125 is available at:
https://TuringLang.github.io/DynamicPPL.jl/previews/PR1125/

@penelopeysm penelopeysm marked this pull request as ready for review November 10, 2025 19:28
@penelopeysm penelopeysm merged commit 8715446 into py/fastldf Nov 10, 2025
3 of 17 checks passed
@penelopeysm penelopeysm deleted the py/fastinit branch November 10, 2025 19:28