Skip to content

Commit

Permalink
Merge pull request #101 from dfdx/more-dynamic-rrule-via-ad
Browse files Browse the repository at this point in the history
More dynamic rrule via ad
  • Loading branch information
dfdx authored Jan 11, 2022
2 parents 2ad0d21 + 532308a commit 5583700
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Yota"
uuid = "cd998857-8626-517d-b929-70ad188a48f0"
authors = ["Andrei Zhabinski <andrei.zhabinski@gmail.com>"]
version = "0.6.2"
version = "0.6.3"

[deps]
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
Expand Down
10 changes: 7 additions & 3 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,15 @@ function ChainRulesCore.rrule_via_ad(::YotaRuleConfig, f, args...)
sig = call_signature(f, args...)
if haskey(GENERATED_RRULE_CACHE, sig)
rr = GENERATED_RRULE_CACHE[sig]
return Base.invokelatest(rr, f, args...)
# return Base.invokelatest(rr, f, args...)
val, pb = Base.invokelatest(rr, f, args...)
return val, dy -> Base.invokelatest(pb, dy)
else
rr = make_rrule(f, args...)
GENERATED_RRULE_CACHE[sig] = rr
return Base.invokelatest(rr, f, args...)
# return Base.invokelatest(rr, f, args...)
val, pb = Base.invokelatest(rr, f, args...)
return val, dy -> Base.invokelatest(pb, dy)
end
end

Expand Down Expand Up @@ -194,7 +198,7 @@ end

function ChainRulesCore.rrule(::typeof(tuple), args...)
y = tuple(args...)
return y, dy -> (NoTangent(), collect(dy...)...)
return y, dy -> (NoTangent(), collect(dy)...)
end

# test_rrule(tuple, 1, 2, 3; output_tangent=Tangent{Tuple}((1, 2, 3)), check_inferred=false)
Expand Down

0 comments on commit 5583700

Please sign in to comment.