Skip to content

Commit

Permalink
Allow rrules with RuleConfig{>:HasReverseMode} (previously only YotaR…
Browse files Browse the repository at this point in the history
…uleConfig was allowed)
  • Loading branch information
dfdx committed Jan 9, 2022
1 parent 61f1f77 commit 2ad0d21
Showing 1 changed file with 26 additions and 13 deletions.
39 changes: 26 additions & 13 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,35 @@ import Ghost: make_name, Input, to_expr
# Primitives #
###############################################################################

function function_signatures(fn=rrule)
"""
Collect list of function signatures for which rrule() or no_rrule() is defined
"""
function rrule_covered_signatures(fn=rrule)
rrule_methods = methods(fn).ms
sigs = [rr.sig for rr in rrule_methods]
# remove `rrule` parameter
sigs = [remove_first_parameter(sig) for sig in sigs]
# remove YotaRuleConfig parameter if any
sigs = [Ghost.get_type_parameters(sig)[1] <: YotaRuleConfig ?
remove_first_parameter(sig) :
sig
for sig in sigs]
rrule_sigs = [rr.sig for rr in rrule_methods]
primal_sigs = []
for rr_sig in rrule_sigs
# remove `rrule` parameter
sig = remove_first_parameter(rr_sig)
Ts = collect(Ghost.get_type_parameters(sig))
# skip rules with config with features that we don't support
if Ts[1] <: RuleConfig && !(Ts[1] <: RuleConfig{>:HasReverseMode})
continue
end
# remove RuleConfig parameter
if Ts[1] <: RuleConfig{>:HasReverseMode}
sig = remove_first_parameter(sig)
end
# now sig looks like the signature of the primal function
push!(primal_sigs, sig)
end
# add keyword version of these functions as well
kw_sigs = [kwsig for kwsig in map(kwfunc_signature, sigs) if kwsig !== Tuple{}]
return [sigs; kw_sigs]
kw_sigs = [kwsig for kwsig in map(kwfunc_signature, primal_sigs) if kwsig !== Tuple{}]
return [primal_sigs; kw_sigs]
end



const CHAINRULES_PRIMITIVES = Ref(FunctionResolver{Bool}())
const NUM_CHAINRULES_METHODS = Ref{Int}(0)

Expand All @@ -30,8 +43,8 @@ function update_chainrules_primitives!(;force=false)
num_methods = length(methods(rrule)) + length(methods(no_rrule))
if force || num_methods != NUM_CHAINRULES_METHODS[]
sigs_flags = [
[sig => true for sig in function_signatures(rrule)];
[sig => false for sig in function_signatures(no_rrule)] # override rrule(sig...)
[sig => true for sig in rrule_covered_signatures(rrule)];
[sig => false for sig in rrule_covered_signatures(no_rrule)] # override rrule(sig...)
]
P = FunctionResolver{Bool}(sigs_flags)
CHAINRULES_PRIMITIVES[] = P
Expand Down

0 comments on commit 2ad0d21

Please sign in to comment.