diff --git a/Project.toml b/Project.toml index 25eae2b..29b0358 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Yota" uuid = "cd998857-8626-517d-b929-70ad188a48f0" authors = ["Andrei Zhabinski "] -version = "0.7.2" +version = "0.7.3" [deps] ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" diff --git a/src/cr_api.jl b/src/cr_api.jl index 0b86655..ee7a350 100644 --- a/src/cr_api.jl +++ b/src/cr_api.jl @@ -10,13 +10,10 @@ import Umlaut: make_name, Input, to_expr, BcastCtx struct ChainRulesCtx end -@inline instance_type(f::F) where {F} = F -@inline instance_type(T::UnionAll) = Type{<:T} -@inline instance_type(T::DataType) = Type{T} function isprimitive(::ChainRulesCtx, f, args...) - F = instance_type(f) - Args = instance_type.(args) + F = Core.Typeof(f) + Args = Core.Typeof.(args) Core.Compiler.return_type(rrule, Tuple{YotaRuleConfig, F, Args...}) !== Nothing && return true if is_kwfunc(F) Args_kwrrule = Tuple{Any, typeof(Core.kwfunc(f)), YotaRuleConfig, Args...,} diff --git a/src/grad.jl b/src/grad.jl index 85a8dbf..6fdd7e7 100644 --- a/src/grad.jl +++ b/src/grad.jl @@ -49,7 +49,7 @@ function record_primitive!(tape::Tape{GradCtx}, v_fargs...) rr_op = (is_kwfunc(f) ? mkcall(Core.kwfunc(rrule), v_args[1], rrule, YOTA_RULE_CONFIG, v_args[2:end]...) : mkcall(rrule, YOTA_RULE_CONFIG, v_f, v_args...)) - @assert rr_op.val !== nothing "rrule($(op.fn), ...) returned nothing" + @assert rr_op.val !== nothing "rrule($f, ...) returned nothing" v_rr = push!(tape, rr_op) v_val = push!(tape, mkcall(_getfield, v_rr, 1)) v_pb = push!(tape, mkcall(_getfield, v_rr, 2)) @@ -70,6 +70,11 @@ struct BcastGradCtx end +# get_static_params is broken for BcastGradCtx, so turning off +# this feature for now +Umlaut.get_static_params(::Tracer{BcastGradCtx}, v_fargs) = Core.svec([]) + + function record_or_recurse!(t::Tracer{BcastGradCtx}, v_fargs...) fargs = [v isa V ? t.tape[v].val : v for v in v_fargs] # global STATE = (t, v_fargs) diff --git a/test/test_cr_api.jl b/test/test_cr_api.jl index a44b610..5be2f5c 100644 --- a/test/test_cr_api.jl +++ b/test/test_cr_api.jl @@ -29,6 +29,8 @@ rrule(::YotaRuleConfig, ::typeof(primitive_test2), x; y=1) = primitive_test2(x; @test val == 7 @test pb(1.0) == (ZeroTangent(), 2.0) + trace(broadcasted, double_dec, [1.0, 2.0]) + rr = make_rrule(broadcasted, double_dec, [1.0, 2.0]) val, pb = rr(config, broadcasted, double_dec, [3.0, 4.0]) @test val == [5.0, 7.0]