Skip to content

Commit

Permalink
Add wip rrule importer
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed May 12, 2024
1 parent ac772c9 commit 769f686
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 3 deletions.
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ version = "0.12.5"

[deps]
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
Enzyme_jll = "7cc45869-7501-5eee-bdea-0790c847d4ef"
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
Expand Down
106 changes: 105 additions & 1 deletion ext/EnzymeChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,4 +104,108 @@ function Enzyme._import_frule(fn, tys...)
end


end # module
"""
import_rrule(::fn, tys...)
Automatically import a ChainRules.rrule as a custom reverse mode EnzymeRule. When called in batch mode, this
will end up calling the primal multiple times which results in slower code. This macro assumes that the underlying
function to be imported is read-only, and returns a Duplicated or Const object. This macro also assumes that the
inputs permit a .+= operation and that the output has a valid Enzyme.make_zero function defined. It also assumes
that overwritten(x) accurately describes if there is any non-preserved data from forward to reverse, not just
the outermost data structure being overwritten as provided by the specification.
Finally, this macro falls back to almost always caching all of the inputs, even if it may not be needed for the
derivative computation.
As a result, this auto importer is also likely to be slower than writing your own rule, and may also be slower
than not having a rule at all.
Use with caution.
```
Enzyme.@import_rrule(typeof(Base.sort), Any);
```
"""
macro import_rrule(fn, tys...)
vals = []
valtys = []
exprs = []
primals = []
tangents = []
tangentsi = []
anns = []
nothings = [(:nothing)]
for (i, ty) in enumerate(tys)
push!(nothings, :(nothing))
val = Symbol("arg_$i")
TA = Symbol("AN_$i")
e = :($val::$TA)
push!(anns, :($TA <: Union{Const, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed}{<:$ty}))
push!(vals, val)
push!(exprs, e)
primal = Symbol("primcopy_$i")
push!(primals, primal)
push!(valtys, :($primal = overwritten(config)[$i+1] ? deepcopy($val.val) : $val.val))
push!(tangents, :($val isa Const ? $ChainRulesCore.NoTangent() : $val.dval))
push!(tangentsi, :($val isa Const ? $ChainRulesCore.NoTangent() : $val.dval[i]))
end

:(
function EnzymeRules.augmented_primal(config::ConfigWidth{batchsize}, fn::FA, ::Type{RetAnnotation}, $(exprs...); kwargs...) where {batchsize, RetAnnotation, FA<:Annotation{<:$fn}, $(anns...)}
$(valtys...)

res, pullback = $ChainRulesCore.rrule(fn.val, $(primals...); kwargs...)

primal = if needs_primal(config)
res
else
nothing
end

shadow = if !needs_shadow(config)
nothing
else
if batchsize == 1
Enzyme.make_zero(res)
else
ntuple(Val(batchsize)) do j
Base.@_inline_meta
Enzyme.make_zero(res)
end
end
end

return AugmentedReturn(primal, shadow, (shadow, pullback))
end

function EnzymeRules.reverse(config::ConfigWidth{batchsize}, fn::FA, ::Type{RetAnnotation}, tape::TapeTy, $(exprs...); kwargs...) where {batchsize, RetAnnotation, TapeTy, FA<:Annotation{<:$fn}, $(anns...)}
shadow, pullback = tape

if batchsize == 1
res = pullback(shadow)
for (cr, en) in zip(res, (fn, $(vals...),))
if en isa Const || cr <: $ChainRulesCore.NoTangent
continue
end
en.dval .+= cr
end
else
ntuple(Val(batchsize)) do i
Base.@_inline_meta
res = pullback(shadow[i])
for (cr, en) in zip(res, (fn, $(vals...),))
if en isa Const || cr <: $ChainRulesCore.NoTangent
continue
end
en.dval[i] .+= cr
end
nothing
end
end

return ($(nothings...),)
end
)
end

end # module
6 changes: 6 additions & 0 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1231,4 +1231,10 @@ macro import_frule(args...)
return _import_frule(args...)
end

function _import_rrule end # defined in EnzymeChainRulesCoreExt extension

macro import_rrule(args...)
return _import_rrule(args...)
end

end # module
61 changes: 61 additions & 0 deletions test/ext/chainrulescore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,67 @@ fdiff(f, x::Number) = autodiff(Forward, f, Duplicated, Duplicated(x, one(x)))[2]
end
end

rdiff(f, x::Number) = autodiff(Reverse, f, Active, Active(x))[1][1]

@testset "import_rrule" begin
f1(x) = 2*x
ChainRulesCore.@scalar_rule f1(x) (5*one(x),)
Enzyme.@import_frule typeof(f1) Any
@test rdiff(f1, 1f0) === 5f0
@test rdiff(f1, 1.0) === 5.0

# specific signature
f2(x) = 2*x
ChainRulesCore.@scalar_rule f2(x) (5*one(x),)
Enzyme.@import_frule typeof(f2) Float32
@test rdiff(f2, 1f0) === 5f0
@test rdiff(f2, 1.0) === 2.0

# two arguments
f3(x, y) = 2*x + y
ChainRulesCore.@scalar_rule f3(x, y) (5*one(x), y)
Enzyme.@import_frule typeof(f3) Any Any
@test rdiff(x -> f3(x, 1.0), 2.) === 5.0
@test rdiff(y -> f3(1.0, y), 2.) === 2.0

@testset "batch duplicated" begin
x = [1.0, 2.0, 0.0]
Enzyme.@import_rrule typeof(Base.sort) Any

test_reverse(Base.sort, Duplicated, (x, Duplicated))
# Unsupported by EnzymeTestUtils
# test_reverse(Base.sort, Duplicated, (x, DuplicatedNoNeed))
test_reverse(Base.sort, DuplicatedNoNeed, (x, Duplicated))
# Unsupported by EnzymeTestUtils
# test_reverse(Base.sort, DuplicatedNoNeed, (x, DuplicatedNoNeed))
test_reverse(Base.sort, Const, (x, Duplicated))
# Unsupported by EnzymeTestUtils
# test_reverse(Base.sort, Const, (x, DuplicatedNoNeed))

test_reverse(Base.sort, Const, (x, Const))

# ChainRules does not support this case (returning notangent)
# test_reverse(Base.sort, Duplicated, (x, Const))
# test_reverse(Base.sort, DuplicatedNoNeed, (x, Const))

test_reverse(Base.sort, BatchDuplicated, (x, BatchDuplicated))
# Unsupported by EnzymeTestUtils
# test_reverse(Base.sort, BatchDuplicated, (x, BatchDuplicatedNoNeed))
test_reverse(Base.sort, BatchDuplicatedNoNeed, (x, BatchDuplicated))
# Unsupported by EnzymeTestUtils
# test_reverse(Base.sort, BatchDuplicatedNoNeed, (x, BatchDuplicatedNoNeed))
test_reverse(Base.sort, Const, (x, BatchDuplicated))
# Unsupported by EnzymeTestUtils
# test_reverse(Base.sort, Const, (x, BatchDuplicatedNoNeed))

# ChainRules does not support this case (returning notangent)
# test_reverse(Base.sort, BatchDuplicated, (x, Const))
# test_reverse(Base.sort, BatchDuplicatedNoNeed, (x, Const))
end
end






Expand Down

0 comments on commit 769f686

Please sign in to comment.