-
Notifications
You must be signed in to change notification settings - Fork 32
/
Copy pathlogdensityfunction.jl
146 lines (119 loc) · 4.81 KB
/
logdensityfunction.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""
    LogDensityFunction

A log density function of a `model`, implementing the interface of LogDensityProblems.jl.

# Fields

$(FIELDS)

# Examples

```jldoctest
julia> using Distributions

julia> using DynamicPPL: LogDensityFunction, contextualize

julia> @model function demo(x)
           m ~ Normal()
           x ~ Normal(m, 1)
       end
demo (generic function with 2 methods)

julia> model = demo(1.0);

julia> f = LogDensityFunction(model);

julia> # It implements the interface of LogDensityProblems.jl.
       using LogDensityProblems

julia> LogDensityProblems.logdensity(f, [0.0])
-2.3378770664093453

julia> LogDensityProblems.dimension(f)
1

julia> # By default it uses `VarInfo` under the hood, but this is not necessary.
       f = LogDensityFunction(model, SimpleVarInfo(model));

julia> LogDensityProblems.logdensity(f, [0.0])
-2.3378770664093453

julia> # This also respects the context in `model`.
       f_prior = LogDensityFunction(contextualize(model, DynamicPPL.PriorContext()), VarInfo(model));

julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0)
true
```
"""
struct LogDensityFunction{V,M,C}
    "varinfo used for evaluation"
    varinfo::V
    "model used for evaluation"
    model::M
    "context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable"
    context::C
end
# TODO: Deprecate.
# Legacy positional form that takes an explicit `sampler`. The `sampler` and `context`
# are combined into a single `SamplingContext`, which is stored as the `context` field.
function LogDensityFunction(
    varinfo::AbstractVarInfo,
    model::Model,
    sampler::AbstractSampler,
    context::AbstractContext,
)
    sampling_context = SamplingContext(sampler, context)
    return LogDensityFunction(varinfo, model, sampling_context)
end
# Model-first constructor: `varinfo` defaults to `VarInfo(model)` and `context` defaults to
# `nothing` (which `getcontext` resolves to the model's leaf context). The arguments are
# forwarded in the struct's field order (varinfo, model, context).
LogDensityFunction(
    model::Model,
    varinfo::AbstractVarInfo=VarInfo(model),
    context::Union{Nothing,AbstractContext}=nothing,
) = LogDensityFunction(varinfo, model, context)
# Resolve the context used for evaluation. An explicitly stored `context` wins; when the
# field is `nothing`, fall back to the leaf context of the wrapped model.
function getcontext(f::LogDensityFunction)
    if isnothing(f.context)
        return leafcontext(f.model.context)
    end
    return f.context
end
"""
    getmodel(f)

Return the `DynamicPPL.Model` wrapped in the log-density function `f`.

For a `LogDensityProblemsAD.ADGradientWrapper`, the model is looked up in the wrapped
(parent) log-density function.
"""
function getmodel(f::LogDensityProblemsAD.ADGradientWrapper)
    inner = LogDensityProblemsAD.parent(f)
    return getmodel(inner)
end
getmodel(f::DynamicPPL.LogDensityFunction) = f.model
"""
    setmodel(f::LogDensityFunction, model)
    setmodel(f::LogDensityProblemsAD.ADGradientWrapper, model, adtype)

Return a log-density function equivalent to `f`, but with its `DynamicPPL.Model`
replaced by `model`.

!!! warning
    Note that if `f` is a `LogDensityProblemsAD.ADGradientWrapper` wrapping a
    `DynamicPPL.LogDensityFunction`, performing an update of the `model` in `f`
    might require recompilation of the gradient tape, depending on the AD backend.
"""
function setmodel(
    f::LogDensityProblemsAD.ADGradientWrapper,
    model::DynamicPPL.Model,
    adtype::ADTypes.AbstractADType,
)
    # TODO: Should we handle `SciMLBase.NoAD`?
    # Replacing the model inside the existing wrapper is not enough. Some backends cache
    # state derived from the old model; for example, ReverseDiff.jl in compiled mode keeps
    # its compiled tape. So the model is updated in the wrapped `LogDensityFunction`, and
    # the wrapper is rebuilt with `ADgradient` for the given `adtype`.
    updated_inner = setmodel(LogDensityProblemsAD.parent(f), model)
    return LogDensityProblemsAD.ADgradient(adtype, updated_inner)
end

function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
    # Field update through Accessors.jl, which builds a new object rather than mutating `f`.
    updated = Accessors.@set f.model = model
    return updated
end
# HACK: heavy usage of `AbstractSampler` for, well, _everything_, is being phased out. In
# the meantime these methods are needed to stay compatible with code that still asks a
# `LogDensityFunction` for its sampler. Both forward to the resolved context.
function getsampler(f::LogDensityFunction)
    return getsampler(getcontext(f))
end

function hassampler(f::LogDensityFunction)
    return hassampler(getcontext(f))
end
# `_get_indexer(ctx)` walks down the context stack and returns the index that `getparams`
# applies to the varinfo: the sampler of the first `SamplingContext` encountered, or
# `Colon()` if a leaf context is reached without finding one.
function _get_indexer(ctx::AbstractContext)
    # Dispatch on whether `ctx` wraps a child context or is a leaf.
    return _get_indexer(NodeTrait(ctx), ctx)
end

function _get_indexer(ctx::SamplingContext)
    return ctx.sampler
end

function _get_indexer(::IsParent, ctx::AbstractContext)
    # Re-enter through the one-argument method so that a nested `SamplingContext` is caught.
    return _get_indexer(childcontext(ctx))
end

function _get_indexer(::IsLeaf, ::AbstractContext)
    return Colon()
end
"""
    getparams(f::LogDensityFunction)

Return the parameters of the wrapped varinfo as a vector.
"""
function getparams(f::LogDensityFunction)
    # Index by the sampler when a `SamplingContext` is in the context stack, otherwise by `:`.
    indexer = _get_indexer(getcontext(f))
    return f.varinfo[indexer]
end
# LogDensityProblems interface

# Log density of the model at the parameter vector `θ`.
#
# Steps:
# 1. Build a varinfo holding `θ` via `unflatten`.
# 2. Evaluate the model with it.
# 3. Read the log probability from the last element returned by `evaluate!!`.
function LogDensityProblems.logdensity(f::LogDensityFunction, θ::AbstractVector)
    ctx = getcontext(f)
    varinfo_at_θ = unflatten(f.varinfo, ctx, θ)
    evaluation = evaluate!!(f.model, varinfo_at_θ, ctx)
    return getlogp(last(evaluation))
end
# Advertise order 0: only `logdensity` is provided for `LogDensityFunction`; gradients are
# expected to come from an AD wrapper such as `LogDensityProblemsAD.ADgradient`.
function LogDensityProblems.capabilities(::Type{T}) where {T<:LogDensityFunction}
    return LogDensityProblems.LogDensityOrder{0}()
end
# TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases
# where no sampler is involved)?
# The dimension is the number of entries returned by `getparams`.
function LogDensityProblems.dimension(f::LogDensityFunction)
    params = getparams(f)
    return length(params)
end