MethodError: no method matching length(::Type{Val{2}}) when differentiating log-likelihood #121
FWIW the error is slightly different on master:
julia> import Yota
julia> normal_pdf(x, mean, var) = exp(-(x - mean)^2 / (2var)) / sqrt(2π * var);
julia> Yota.grad((x, mu) -> sum(log, normal_pdf.(x, mu, 1.0)), rand(10), 1.0)
┌ Error: Failed to compile rrule for broadcasted(normal_pdf, [0.8031553730805592, 0.3509560552123825, 0.032551822966513155, 0.21170603638555385, 0.2049628078398853, 0.8469464153815023, 0.4220217037413583, 0.286621858419426, 0.0338940405059448, 0.43421951685956195], 1.0, 1.0), extract details via:
│ (f, args) = Yota.RRULE_VIA_AD_STATE[]
└ @ Yota ~/.julia/packages/Yota/QGPcM/src/cr_api.jl:179
ERROR: MethodError: no method matching iterate(::Type{Val})
Closest candidates are:
iterate(::Union{LinRange, StepRangeLen})
@ Base range.jl:869
iterate(::Union{LinRange, StepRangeLen}, ::Integer)
@ Base range.jl:869
iterate(::T) where T<:Union{Base.KeySet{<:Any, <:Dict}, Base.ValueIterator{<:Dict}}
@ Base dict.jl:698
...
Stacktrace:
[1] first(itr::Type)
@ Base ./abstractarray.jl:436
[2] map(f::typeof(first), t::Tuple{UnionAll, Int64})
@ Base ./tuple.jl:274
[3] trace_call!(::Umlaut.Tracer{Yota.BcastGradCtx}, ::Function, ::Vararg{Any})
@ Yota ~/.julia/packages/Yota/QGPcM/src/grad.jl:92
[4] trace_block!(t::Umlaut.Tracer{Yota.BcastGradCtx}, ir::Core.Compiler.IRCode, bi::Int64, prev_bi::Int64, sparams::Core.SimpleVector)
@ Umlaut ~/.julia/packages/Umlaut/LH23t/src/trace.jl:312
[5] trace!(t::Umlaut.Tracer{Yota.BcastGradCtx}, v_fargs::Vector{Umlaut.Variable})
@ Umlaut ~/.julia/packages/Umlaut/LH23t/src/trace.jl:436
[6] trace(::Function, ::Vector{Float64}, ::Vararg{Any}; ctx::Yota.BcastGradCtx, deprecated_kws::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ Umlaut ~/.julia/packages/Umlaut/LH23t/src/trace.jl:546
[7] make_rrule(::typeof(Base.Broadcast.broadcasted), ::Function, ::Vector{Float64}, ::Vararg{Any})
@ Yota ~/.julia/packages/Yota/QGPcM/src/cr_api.jl:136
[8] rrule_via_ad(::Yota.YotaRuleConfig, ::Function, ::Function, ::Vararg{Any})
@ Yota ~/.julia/packages/Yota/QGPcM/src/cr_api.jl:172
[9] rrule(::Yota.YotaRuleConfig, ::typeof(Base.Broadcast.broadcasted), ::typeof(normal_pdf), ::Vector{Float64}, ::Float64, ::Float64)
@ Yota ~/.julia/packages/Yota/QGPcM/src/rulesets.jl:91
[10] mkcall(::Function, ::Yota.YotaRuleConfig, ::Vararg{Any}; val::Missing, line::Core.LineInfoNode, kwargs::NamedTuple{(), Tuple{}}, free_kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ Umlaut ~/.julia/packages/Umlaut/LH23t/src/tape.jl:202
[11] record_primitive!(::Umlaut.Tape{Yota.GradCtx}, ::Function, ::Vararg{Any})
@ Yota ~/.julia/packages/Yota/QGPcM/src/grad.jl:54
[12] trace_call!(::Umlaut.Tracer{Yota.GradCtx}, ::Function, ::Vararg{Any})
@ Umlaut ~/.julia/packages/Umlaut/LH23t/src/trace.jl:286
[13] trace_block!(t::Umlaut.Tracer{Yota.GradCtx}, ir::Core.Compiler.IRCode, bi::Int64, prev_bi::Int64, sparams::Core.SimpleVector)
@ Umlaut ~/.julia/packages/Umlaut/LH23t/src/trace.jl:312
[14] trace!(t::Umlaut.Tracer{Yota.GradCtx}, v_fargs::Vector{Umlaut.Variable})
@ Umlaut ~/.julia/packages/Umlaut/LH23t/src/trace.jl:436
[15] trace(::Function, ::Vector{Float64}, ::Vararg{Any}; ctx::Yota.GradCtx, deprecated_kws::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ Umlaut ~/.julia/packages/Umlaut/LH23t/src/trace.jl:546
[16] gradtape(::var"#3#4", ::Vector{Float64}, ::Vararg{Any}; ctx::Yota.GradCtx, seed::Int64)
@ Yota ~/.julia/packages/Yota/QGPcM/src/grad.jl:258
[17] grad(::var"#3#4", ::Vector{Float64}, ::Vararg{Any}; seed::Int64)
@ Yota ~/.julia/packages/Yota/QGPcM/src/grad.jl:334
[18] grad(::var"#3#4", ::Vector{Float64}, ::Vararg{Any})
@ Yota ~/.julia/packages/Yota/QGPcM/src/grad.jl:326
[19] top-level scope
@ REPL[5]:1
(jl_atLVz9) pkg> st
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_atLVz9/Project.toml`
[92992a2b] Umlaut v0.4.2 `https://github.com/dfdx/Umlaut.jl.git#main`
[cd998857] Yota v0.7.5 `https://github.com/dfdx/Yota.jl.git#main`
If I disable all rules related to broadcasting in Yota, then the error is more straightforward, and I think it tells us that the rules from JuliaDiff/ChainRules.jl#644 (which ought to cover this kind of broadcasting) are not being called:
julia> using Yota
julia> normal_pdf(x, mean, var) = exp(-(x - mean)^2 / (2var)) / sqrt(2π * var);
julia> Yota.grad((x, mu) -> sum(log, normal_pdf.(x, mu, 1.0)), rand(10), 1.0)
ERROR: No deriative rule found for op %5 = materialize(%4)::Vector{Float64} , try defining it using
ChainRulesCore.rrule(::typeof(Base.Broadcast.materialize), ::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(normal_pdf), Tuple{Vector{Float64}, Float64, Float64}}) = ...
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
@ Yota ~/.julia/dev/Yota/src/grad.jl:185
...
# same error on a much easier broadcast, should use derivatives_given_output
julia> Yota.grad((x, mu) -> sum(log, atan.(x, mu)), rand(10), 1.0)
ERROR: No deriative rule found for op %5 = materialize(%4)::Vector{Float64} , try defining it using ...
# and an even easier one, has its own rrule(broadcasted, +, ...)
julia> Yota.grad((x, mu) -> sum(log, x .+ mu), rand(10), 1.0)
ERROR: No deriative rule found for op %7 = materialize(%5)::Vector{Float64} , try defining it using ...
(jl_fINTq0) pkg> st
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_fINTq0/Project.toml`
[082447d4] ChainRules v1.43.2
[92992a2b] Umlaut v0.4.2 `https://github.com/dfdx/Umlaut.jl.git#main`
[cd998857] Yota v0.7.5 `~/.julia/dev/Yota`
Oh, I think I messed up broadcasting on
@mcabbott If I understand it correctly, JuliaDiff/ChainRules.jl#644 doesn't provide an
julia> grad((x, mu) -> sum(log, x .+ mu), rand(10), 1.0)
(2.515783933810798, (ZeroTangent(), [0.5835340506097648, 0.6717831140646412, 0.9651554343975317, 0.5720477538411006, 0.9756415208897008, 0.7392737693342455, 0.9473341524270764, 0.9914851197981195, 0.9458598901263214, 0.5826024579088742], 7.974717263397377))
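For anyone wanting to sanity-check a result like the one above without relying on any AD package, here is a minimal sketch in plain Julia. The helper names `loss`, `analytic_grad` and `fd_mu` are hypothetical and not part of Yota or ChainRules; the gradients are derived by hand from `sum(log, x .+ mu)`. The hand derivation predicts that the `mu` gradient equals the sum of the `x` gradient entries, which appears to agree with the output above (the ten entries add up to roughly 7.9747).

```julia
# Objective from the thread: the sum of log(x_i + mu).
loss(x, mu) = sum(log, x .+ mu)

# Hand-derived gradients (an assumption of this sketch, not Yota output):
#   d/dx_i = 1 / (x_i + mu)
#   d/dmu  = sum over i of 1 / (x_i + mu)
analytic_grad(x, mu) = (1 ./ (x .+ mu), sum(1 ./ (x .+ mu)))

# Central finite difference in mu, used as an independent check.
fd_mu(x, mu; h=1e-6) = (loss(x, mu + h) - loss(x, mu - h)) / (2h)

x = rand(10)
mu = 1.0
gx, gmu = analytic_grad(x, mu)

# The mu gradient is the sum of the x gradient entries.
@assert gmu ≈ sum(gx)
# The analytic mu gradient agrees with the finite-difference estimate.
@assert isapprox(gmu, fd_mu(x, mu); rtol=1e-6)
```

The same comparison could in principle be run against whatever `grad` returns, although the exact layout of the returned tuple should be taken from the output above rather than from this sketch.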
Oh right, sorry, the rule for
The fix is now on
Thanks! Are you thinking of making a release sometime soon?
@mcabbott Yep, tagged v0.7.4.