ThreadSafeVarInfo(::SimpleVarInfo) bug with ForwardDiff #524

Closed
marius311 opened this issue Aug 30, 2023 · 1 comment

marius311 commented Aug 30, 2023

A few users have reported this error when running my package MuseInference, and I believe it boils down to a bug with SimpleVarInfo + threads + ForwardDiff. The problem is that the logp vector in ThreadSafeVarInfo is concretely typed for Float64 (or whatever the logp precision is), so ForwardDiff.Duals can't be stored in it.
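
(For illustration, a minimal sketch of the underlying typing issue outside of DynamicPPL; the same conversion failure happens when acclogp!! writes into the concretely typed logp vector:)

using ForwardDiff

logps = zeros(2)                  # concretely typed Vector{Float64}, like ThreadSafeVarInfo's logp vector
d = ForwardDiff.Dual(-1.42, 1.0)  # a value with a nonzero partial
logps[1] += d                     # errors: a Dual with nonzero partials cannot be converted back to Float64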

Here's an MWE (DynamicPPL v0.23.13, Turing v0.28.2), which you must run with --threads=2 (or more) to trigger the error; with a single thread there is no error:

using Turing, ForwardDiff
import DynamicPPL as DynPPL

@model function foo()
    x ~ Normal()
end

model = foo()

ForwardDiff.derivative(1.) do x
    vi = DynPPL.SimpleVarInfo((;x), 0., DynPPL.NoTransformation())
    DynPPL.logprior(model, vi)
end

(stack trace below)

One solution is to relax the logp type (although that makes things type-unstable; maybe it doesn't matter, and I'm not sure what you would want as a fix):

function DynPPL.ThreadSafeVarInfo(vi::DynPPL.SimpleVarInfo)
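    # use an abstractly typed Vector{Real} so Dual log-probabilities can be stored (at the cost of type stability)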
    return DynPPL.ThreadSafeVarInfo(vi, Vector{Real}(zeros(typeof(DynPPL.getlogp(vi)), Threads.nthreads())))
end

If this is a bug, it would be great to get it fixed. If it's instead a problem with how I'm using the DynamicPPL internals, please let me know.

ERROR: InexactError: Int(Int64, Dual{ForwardDiff.Tag{var"#9#10", Float64}}(-1.4189385332046727,-1.0))
Stacktrace:
  [1] Int64
    @ ~/.julia/packages/ForwardDiff/PcZ48/src/dual.jl:364 [inlined]
  [2] convert
    @ ./number.jl:7 [inlined]
  [3] setindex!
    @ ./array.jl:969 [inlined]
  [4] acclogp!!
    @ ~/.julia/packages/DynamicPPL/W8pRQ/src/threadsafe.jl:25 [inlined]
  [5] tilde_assume!!(context::DynamicPPL.PriorContext{Nothing}, right::Normal{Float64}, vn::AbstractPPL.VarName{:x, Setfield.IdentityLens}, vi::DynamicPPL.ThreadSafeVarInfo{DynamicPPL.SimpleVarInfo{NamedTuple{(:x,), Tuple{ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 1}}}, Int64, DynamicPPL.NoTransformation}, Vector{Int64}})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/W8pRQ/src/context_implementations.jl:118
  [6] macro expansion
    @ ~/.julia/packages/DynamicPPL/W8pRQ/src/compiler.jl:555 [inlined]
  [7] foo
    @ ./REPL[3]:1 [inlined]
  [8] _evaluate!!
    @ ~/.julia/packages/DynamicPPL/W8pRQ/src/model.jl:963 [inlined]
  [9] evaluate_threadsafe!!(model::DynamicPPL.Model{typeof(foo), (), (), (), Tuple{}, Tuple{}, DynamicPPL.DefaultContext}, varinfo::DynamicPPL.SimpleVarInfo{NamedTuple{(:x,), Tuple{ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 1}}}, Int64, DynamicPPL.NoTransformation}, context::DynamicPPL.PriorContext{Nothing})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/W8pRQ/src/model.jl:952
 [10] evaluate!!
    @ ~/.julia/packages/DynamicPPL/W8pRQ/src/model.jl:887 [inlined]
 [11] logprior(model::DynamicPPL.Model{typeof(foo), (), (), (), Tuple{}, Tuple{}, DynamicPPL.DefaultContext}, varinfo::DynamicPPL.SimpleVarInfo{NamedTuple{(:x,), Tuple{ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 1}}}, Int64, DynamicPPL.NoTransformation})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/W8pRQ/src/model.jl:1106
 [12] (::var"#9#10")(x::ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 1})
    @ Main ./REPL[5]:3
 [13] derivative(f::var"#9#10", x::Float64)
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/derivative.jl:14
 [14] top-level scope
    @ REPL[5]:1
@marius311 (Author)

Actually, after digging some more, I'm guessing I'm using the internals wrong, and I should follow what the other constructors do, which is to use float_type_with_fallback and infer_nested_eltype to get a correctly typed zero:

SimpleVarInfo{float_type_with_fallback(infer_nested_eltype(typeof(θ)))}(θ)
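
(An untested sketch of the MWE adjusted accordingly, assuming the SimpleVarInfo{T}(θ) constructor and the two helper functions behave as described above:)

ForwardDiff.derivative(1.) do x
    θ = (; x)
    # infer the logp eltype from the parameter values so Duals propagate into the logp zero
    vi = DynPPL.SimpleVarInfo{DynPPL.float_type_with_fallback(DynPPL.infer_nested_eltype(typeof(θ)))}(θ)
    DynPPL.logprior(model, vi)
end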
