@@ -4,28 +4,57 @@ using ADTypes: AbstractADType, AutoForwardDiff
 using Chairmarks: @be
 import DifferentiationInterface as DI
 using DocStringExtensions
-using DynamicPPL:
-    Model,
-    LogDensityFunction,
-    VarInfo,
-    AbstractVarInfo,
-    link,
-    DefaultContext,
-    AbstractContext
+using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, link
 using LogDensityProblems: logdensity, logdensity_and_gradient
-using Random: Random, Xoshiro
+using Random: AbstractRNG, default_rng
 using Statistics: median
 using Test: @test
 
-export ADResult, run_ad, ADIncorrectException
+export ADResult, run_ad, ADIncorrectException, WithBackend, WithExpectedResult, NoTest
 
 """
-    REFERENCE_ADTYPE
+    AbstractADCorrectnessTestSetting
 
-Reference AD backend to use for comparison. In this case, ForwardDiff.jl, since
-it's the default AD backend used in Turing.jl.
+Different ways of testing the correctness of an AD backend.
 """
-const REFERENCE_ADTYPE = AutoForwardDiff()
+abstract type AbstractADCorrectnessTestSetting end
+
+"""
+    WithBackend(adtype::AbstractADType=AutoForwardDiff()) <: AbstractADCorrectnessTestSetting
+
+Test correctness by comparing it against the result obtained with `adtype`.
+
+`adtype` defaults to ForwardDiff.jl, since it's the default AD backend used in
+Turing.jl.
+"""
+struct WithBackend{AD<:AbstractADType} <: AbstractADCorrectnessTestSetting
+    adtype::AD
+end
+WithBackend() = WithBackend(AutoForwardDiff())
+
+"""
+    WithExpectedResult(
+        value::T,
+        grad::AbstractVector{T}
+    ) where {T<:AbstractFloat}
+    <: AbstractADCorrectnessTestSetting
+
+Test correctness by comparing it against a known result (e.g. one obtained
+analytically, or one obtained with a different backend previously). Both the
+value of the primal (i.e. the log-density) as well as its gradient must be
+supplied.
+"""
+struct WithExpectedResult{T<:AbstractFloat} <: AbstractADCorrectnessTestSetting
+    value::T
+    grad::AbstractVector{T}
+end
+
+"""
+    NoTest() <: AbstractADCorrectnessTestSetting
+
+Disable correctness testing.
+"""
+struct NoTest <: AbstractADCorrectnessTestSetting end
 
 """
     ADIncorrectException{T<:AbstractFloat}
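As a quick illustration of the three settings introduced above (not part of the patch; the `AutoReverseDiff` choice and the numbers are hypothetical, and the types are assumed to be in scope):

```julia
using ADTypes: AutoForwardDiff, AutoReverseDiff

setting = WithBackend()                            # compare against ForwardDiff (the default)
setting = WithBackend(AutoReverseDiff())           # compare against an explicitly chosen backend
setting = WithExpectedResult(-1.23, [0.5, -0.5])   # compare against a known value and gradient
setting = NoTest()                                 # skip correctness testing entirely
```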
@@ -45,17 +74,18 @@ struct ADIncorrectException{T<:AbstractFloat} <: Exception
 end
 
 """
-    ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
+    ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat,Ttol<:AbstractFloat}
 
 Data structure to store the results of the AD correctness test.
 
 The type parameter `Tparams` is the numeric type of the parameters passed in;
-`Tresult` is the type of the value and the gradient.
+`Tresult` is the type of the value and the gradient; and `Ttol` is the type of the
+absolute and relative tolerances used for correctness testing.
 
 # Fields
 $(TYPEDFIELDS)
 """
-struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
+struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat,Ttol<:AbstractFloat}
     "The DynamicPPL model that was tested"
     model::Model
     "The VarInfo that was used"
@@ -64,18 +94,18 @@ struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
     params::Vector{Tparams}
     "The AD backend that was tested"
     adtype::AbstractADType
-    "The absolute tolerance for the value of logp"
-    value_atol::Tresult
-    "The absolute tolerance for the gradient of logp"
-    grad_atol::Tresult
+    "Absolute tolerance used for correctness test"
+    atol::Ttol
+    "Relative tolerance used for correctness test"
+    rtol::Ttol
     "The expected value of logp"
     value_expected::Union{Nothing,Tresult}
     "The expected gradient of logp"
    grad_expected::Union{Nothing,Vector{Tresult}}
     "The value of logp (calculated using `adtype`)"
-    value_actual::Union{Nothing,Tresult}
+    value_actual::Tresult
     "The gradient of logp (calculated using `adtype`)"
-    grad_actual::Union{Nothing,Vector{Tresult}}
+    grad_actual::Vector{Tresult}
     "If benchmarking was requested, the time taken by the AD backend to calculate the gradient of logp, divided by the time taken to evaluate logp itself"
     time_vs_primal::Union{Nothing,Tresult}
 end
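A sketch of how a caller might inspect the resulting `ADResult`, using the field names from the struct above (`model` and `adtype` are assumed to be defined; `run_ad` is documented below):

```julia
result = run_ad(model, adtype; benchmark=true)

result.value_actual    # logp computed with `adtype`; no longer Union{Nothing,...}
result.grad_actual     # gradient computed with `adtype`
result.atol            # absolute tolerance used for the correctness test
result.rtol            # relative tolerance used for the correctness test
result.time_vs_primal  # gradient time / primal time, or `nothing` without benchmarking
```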
@@ -84,14 +114,12 @@
     run_ad(
         model::Model,
         adtype::ADTypes.AbstractADType;
-        test=true,
+        test::Union{AbstractADCorrectnessTestSetting,Bool}=WithBackend(),
         benchmark=false,
-        value_atol=1e-6,
-        grad_atol=1e-6,
+        atol::AbstractFloat=100 * eps(),
+        rtol::AbstractFloat=sqrt(eps()),
         varinfo::AbstractVarInfo=link(VarInfo(model), model),
         params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
-        reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
-        expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
         verbose=true,
     )::ADResult
 
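To make the keyword changes concrete, a sketch of the call-site migration they imply (`logp_ref` and `grad_ref` are hypothetical placeholders for a precomputed reference result):

```julia
# Before this PR: reference result and tolerances passed as separate keywords.
run_ad(model, adtype; expected_value_and_grad=(logp_ref, grad_ref),
       value_atol=1e-6, grad_atol=1e-6)

# After this PR: the reference result is wrapped in a test setting, and a
# single atol/rtol pair applies to both the value and the gradient.
run_ad(model, adtype; test=WithExpectedResult(logp_ref, grad_ref),
       atol=1e-6, rtol=sqrt(eps()))
```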
@@ -133,8 +161,8 @@ Everything else is optional, and can be categorised into several groups:
 
    Note that if the VarInfo is not specified (and thus automatically generated)
    the parameters in it will have been sampled from the prior of the model. If
-   you want to seed the parameter generation, the easiest way is to pass a
-   `rng` argument to the VarInfo constructor (i.e. do `VarInfo(rng, model)`).
+   you want to seed the parameter generation for the VarInfo, you can pass the
+   `rng` keyword argument, which will then be used to create the VarInfo.
 
    Finally, note that these only reflect the parameters used for _evaluating_
    the gradient. If you also want to control the parameters used for
@@ -143,25 +171,35 @@ Everything else is optional, and can be categorised into several groups:
    prep_params)`. You could then evaluate the gradient at a different set of
    parameters using the `params` keyword argument.
 
-3. _How to specify the results to compare against._ (Only if `test=true`.)
+3. _How to specify the results to compare against._
 
    Once logp and its gradient have been calculated with the specified `adtype`,
-   it must be tested for correctness.
+   they can optionally be tested for correctness. The exact way this is tested
+   is specified in the `test` parameter.
+
+   There are several options for this:
 
-   This can be done either by specifying `reference_adtype`, in which case logp
-   and its gradient will also be calculated with this reference in order to
-   obtain the ground truth; or by using `expected_value_and_grad`, which is a
-   tuple of `(logp, gradient)` that the calculated values must match. The
-   latter is useful if you are testing multiple AD backends and want to avoid
-   recalculating the ground truth multiple times.
+   - You can explicitly specify the correct value using
+     [`WithExpectedResult()`](@ref).
+   - You can compare against the result obtained with a different AD backend
+     using [`WithBackend(adtype)`](@ref).
+   - You can disable testing by passing [`NoTest()`](@ref).
+   - The default is to compare against the result obtained with ForwardDiff,
+     i.e. `WithBackend(AutoForwardDiff())`.
+   - `test=false` and `test=true` are synonyms for
+     `NoTest()` and `WithBackend(AutoForwardDiff())`, respectively.
 
-   The default reference backend is ForwardDiff. If none of these parameters are
-   specified, ForwardDiff will be used to calculate the ground truth.
+4. _How to specify the tolerances._ (Only if testing is enabled.)
 
-4. _How to specify the tolerances._ (Only if `test=true`.)
+   Both absolute and relative tolerances can be specified using the `atol` and
+   `rtol` keyword arguments respectively. The behaviour of these is similar to
+   `isapprox()`, i.e. the value and gradient are considered correct if either
+   `atol` or `rtol` is satisfied. The default values are `100*eps()` for `atol`
+   and `sqrt(eps())` for `rtol`.
 
-   The tolerances for the value and gradient can be set using `value_atol` and
-   `grad_atol`. These default to 1e-6.
+   For the most part, it is the `rtol` check that is more meaningful, because
+   we cannot know the magnitude of logp and its gradient a priori. The `atol`
+   value is supplied to handle the case where gradients are equal to zero.
 
 5. _Whether to output extra logging information._
 
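The tolerance behaviour described above matches Julia's `isapprox`, which passes when either tolerance is satisfied; a small illustration with the new defaults (for `Float64`, `100*eps()` is about `2.2e-14` and `sqrt(eps())` about `1.5e-8`):

```julia
atol, rtol = 100 * eps(), sqrt(eps())

isapprox(1e8 + 1.0, 1e8; atol=atol, rtol=rtol)  # true: rtol dominates at large magnitude
isapprox(1e-15, 0.0; atol=atol, rtol=rtol)      # true: atol covers exactly-zero gradients
isapprox(1.0, 1.1; atol=atol, rtol=rtol)        # false: outside both tolerances
```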
@@ -180,48 +218,58 @@ thrown as-is.
 function run_ad(
     model::Model,
     adtype::AbstractADType;
-    test::Bool=true,
+    test::Union{AbstractADCorrectnessTestSetting,Bool}=WithBackend(),
     benchmark::Bool=false,
-    value_atol::AbstractFloat=1e-6,
-    grad_atol::AbstractFloat=1e-6,
-    varinfo::AbstractVarInfo=link(VarInfo(model), model),
+    atol::AbstractFloat=100 * eps(),
+    rtol::AbstractFloat=sqrt(eps()),
+    rng::AbstractRNG=default_rng(),
+    varinfo::AbstractVarInfo=link(VarInfo(rng, model), model),
     params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
-    reference_adtype::AbstractADType=REFERENCE_ADTYPE,
-    expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
     verbose=true,
 )::ADResult
+    # Convert Boolean `test` to an AbstractADCorrectnessTestSetting
+    if test isa Bool
+        test = test ? WithBackend() : NoTest()
+    end
+
+    # Extract parameters
     if isnothing(params)
         params = varinfo[:]
     end
     params = map(identity, params)  # Concretise
 
+    # Calculate log-density and gradient with the backend of interest
     verbose && @info "Running AD on $(model.f) with $(adtype)\n"
     verbose && println("params  : $(params)")
     ldf = LogDensityFunction(model, varinfo; adtype=adtype)
-
     value, grad = logdensity_and_gradient(ldf, params)
+    # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754
     grad = collect(grad)
     verbose && println("actual  : $((value, grad))")
 
-    if test
-        # Calculate ground truth to compare against
-        value_true, grad_true = if expected_value_and_grad === nothing
-            ldf_reference = LogDensityFunction(model, varinfo; adtype=reference_adtype)
-            logdensity_and_gradient(ldf_reference, params)
-        else
-            expected_value_and_grad
+    # Test correctness
+    if test isa NoTest
+        value_true = nothing
+        grad_true = nothing
+    else
+        # Get the correct result
+        if test isa WithExpectedResult
+            value_true = test.value
+            grad_true = test.grad
+        elseif test isa WithBackend
+            ldf_reference = LogDensityFunction(model, varinfo; adtype=test.adtype)
+            value_true, grad_true = logdensity_and_gradient(ldf_reference, params)
+            # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754
+            grad_true = collect(grad_true)
         end
+        # Perform testing
         verbose && println("expected : $((value_true, grad_true))")
-        grad_true = collect(grad_true)
-
         exc() = throw(ADIncorrectException(value, value_true, grad, grad_true))
-        isapprox(value, value_true; atol=value_atol) || exc()
-        isapprox(grad, grad_true; atol=grad_atol) || exc()
-    else
-        value_true = nothing
-        grad_true = nothing
+        isapprox(value, value_true; atol=atol, rtol=rtol) || exc()
+        isapprox(grad, grad_true; atol=atol, rtol=rtol) || exc()
     end
 
+    # Benchmark
     time_vs_primal = if benchmark
         primal_benchmark = @be (ldf, params) logdensity(_[1], _[2])
         grad_benchmark = @be (ldf, params) logdensity_and_gradient(_[1], _[2])
@@ -237,8 +285,8 @@ function run_ad(
         varinfo,
         params,
         adtype,
-        value_atol,
-        grad_atol,
+        atol,
+        rtol,
         value_true,
         grad_true,
         value,
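Finally, an end-to-end usage sketch under the new API. The demo model is made up, and the `DynamicPPL.TestUtils.AD` module path is an assumption about where this code lives; adjust as needed:

```julia
using DynamicPPL, Distributions
using ADTypes: AutoReverseDiff
using ReverseDiff: ReverseDiff  # the backend package itself must be loaded
using DynamicPPL.TestUtils.AD: run_ad, WithBackend  # assumed module path

@model function demo()
    x ~ Normal()
end

# Compare ReverseDiff against the default ForwardDiff reference, and benchmark it.
result = run_ad(demo(), AutoReverseDiff(); test=WithBackend(), benchmark=true)
@show result.time_vs_primal
```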