Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix broadcast #128

Merged
merged 6 commits into from
Sep 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Yota"
uuid = "cd998857-8626-517d-b929-70ad188a48f0"
authors = ["Andrei Zhabinski <andrei.zhabinski@gmail.com>"]
version = "0.8.0"
version = "0.8.1"

[deps]
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
Expand All @@ -11,6 +11,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
Umlaut = "92992a2b-8ce5-4a9c-bb9d-58be9a7dc841"
Expand Down
77 changes: 27 additions & 50 deletions src/cr_api.jl → src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ struct YotaRuleConfig <: RuleConfig{Union{NoForwardsMode,HasReverseMode}} end

function to_rrule_expr(tape::Tape)
# TODO (maybe): add YotaRuleConfig() as the first argument for consistency
fn_name = gensym("rrule_$(tape[V(1)].val)")
fn_name = :(ChainRulesCore.rrule)
header = Expr(:call, fn_name)
push!(header.args, Expr(:(::), :config, YotaRuleConfig))
for v in inputs(tape)
Expand Down Expand Up @@ -83,64 +83,41 @@ Examples:
pb(1.0)

"""
make_rrule(tape::Tape) = Base.eval(@__MODULE__, to_rrule_expr(tape))
make_rrule!(tape::Tape) = Base.eval(@__MODULE__, to_rrule_expr(tape))

function make_rrule(f, args...)
function make_rrule!(f, args...)
arg_str = join(["::$(typeof(a))" for a in args], ", ")
@debug "Generating new rrule for $(f)($arg_str)"
tape = gradtape(f, args...; seed=:auto, ctx=GradCtx())
return make_rrule(tape)
make_rrule!(tape)
end

# function make_rrule(::typeof(broadcasted), f, args...)
# if isprimitive(GradCtx(), f, map(first, args)...)
# return bcast_rrule # (YOTA_RULE_CONFIG, broadcasted, f, args...)
# end
# ctx = BcastGradCtx(GradCtx())
# _, tape = trace(f, args...; ctx=ctx)
# tape = Tape(tape; ctx=ctx.inner)
# gradtape!(tape, seed=:auto)
# # insert imaginary broadcasted to the list of inputs
# insert!(tape, 1, Umlaut.Input(broadcasted))
# # insert ZeroTangent to the result to account for the additional argument
# grad_tuple_op = tape[V(tape.result.id - 2)]
# @assert grad_tuple_op isa Call && grad_tuple_op.fn == tuple
# grad_tuple_op.args = [ZeroTangent(), grad_tuple_op.args...]
# for id=grad_tuple_op.id:grad_tuple_op.id + 2
# Umlaut.exec!(tape, tape[V(id)])
# end
# return make_rrule(tape)
# end


const GENERATED_RRULE_CACHE = Dict()
const RRULE_VIA_AD_STATE = Ref{Tuple}()


"""
rrule_via_ad(::YotaRuleConfig, f, args...)

Generate `rrule` using Yota.
"""
function ChainRulesCore.rrule_via_ad(::YotaRuleConfig, f, args...)
res = rrule(f, args...)
!isnothing(res) && return res
sig = map(typeof, (f, args...))
if haskey(GENERATED_RRULE_CACHE, sig)
rr = GENERATED_RRULE_CACHE[sig]
# return Base.invokelatest(rr, f, args...)
val, pb = Base.invokelatest(rr, YOTA_RULE_CONFIG, f, args...)
return val, dy -> Base.invokelatest(pb, dy)
else
try
rr = make_rrule(f, args...)
GENERATED_RRULE_CACHE[sig] = rr
# return Base.invokelatest(rr, f, args...)
val, pb = Base.invokelatest(rr, YOTA_RULE_CONFIG, f, args...)
return val, dy -> Base.invokelatest(pb, dy)
catch
RRULE_VIA_AD_STATE[] = (f, args)
@error("Failed to compile rrule for $(f)$args, extract details via:\n" *
"\t(f, args) = Yota.RRULE_VIA_AD_STATE[]")
rethrow()
end
function ChainRulesCore.rrule_via_ad(cfg::YotaRuleConfig, f, args...)
arg_type_str = join(["::$(typeof(a))" for a in args], ", ")
@debug "Running rrule_via_ad() for $f($arg_type_str)"
res = rrule(cfg, f, args...)
if !isnothing(res)
y, pb = res
return y, pb
end
@debug "No rrule in older world ages, falling back to invokelatest"
res = Base.invokelatest(rrule, cfg, f, args...)
if !isnothing(res)
y, pb_ = res
# note: returned pullback is still in future, so we re-wrap it into invokelatest too
pb = dy -> Base.invokelatest(pb_, dy)
return y, pb
end
@debug "No rrule in the latest world age, compiling a new one"
make_rrule!(f, args...)
res = Base.invokelatest(rrule, cfg, f, args...)
y, pb_ = res
pb = dy -> Base.invokelatest(pb_, dy)
return y, pb
end
2 changes: 1 addition & 1 deletion src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ const broadcasted = Broadcast.broadcasted
include("helpers.jl")
include("utils.jl")
include("deprecated.jl")
include("cr_api.jl")
include("chainrules.jl")
include("rulesets.jl")
include("grad.jl")
include("update.jl")
Expand Down
9 changes: 6 additions & 3 deletions src/grad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ function has_rrule(f, args...)
Args = Core.Typeof.(args)
Core.Compiler.return_type(rrule, Tuple{YotaRuleConfig, F, Args...}) !== Nothing && return true
if is_kwfunc(F)
Args_kwrrule = Tuple{Any, typeof(Core.kwfunc(f)), YotaRuleConfig, Args...,}
# must be: Tuple{Any, typeof(rrule), YotaRuleConfig, typeof(unkwfunc(f)), Args[3:end]...}
nokw_f = unkwfunc(f, args...)
Args_kwrrule = Tuple{Any, typeof(rrule), YotaRuleConfig, typeof(nokw_f), Args[3:end]...}
Core.Compiler.return_type(Core.kwfunc(rrule), Args_kwrrule) !== Nothing && return true
end
return false
Expand Down Expand Up @@ -171,6 +173,7 @@ function chainrules_transform!(tape::Tape)
while i <= length(tape)
op = tape[V(i)]
if op isa Call && isprimitive(ChainRulesCtx(), call_values(op)...)
global STATE = tape, op
# replace f(args...) with rrule(f, args...)
v_f, v_args, line = op.fn, op.args, op.line
f = op.fn isa V ? tape[op.fn].val : op.fn
Expand Down Expand Up @@ -213,7 +216,7 @@ function step_back!(tape::Tape, y::Variable)
y_fargs = is_kwfunc(rr._op.fn) ? tape[rr].args[4:end] : tape[rr].args[2:end]
else
sig_str = join(["::$T" for T in Umlaut.call_signature(tape, tape[y]).parameters], ", ")
error("No deriative rule found for op $(tape[y]), " *
error("No derivative rule found for op $(tape[y]), " *
"try defining it using \n\n\tChainRulesCore.rrule($sig_str) = ...\n")
end
for (i, x) in enumerate(y_fargs)
Expand Down Expand Up @@ -362,7 +365,7 @@ function grad(f, args...; seed=1)
cache_key = map(typeof, (f, args...))
if haskey(GRAD_CACHE, cache_key)
gf = GRAD_CACHE[cache_key]
return gf(f, args...; seed=seed)
return Base.invokelatest(gf, f, args...; seed=seed)
else
tape = gradtape(f, args...; seed=seed)
gf = grad_compile(tape)
Expand Down
27 changes: 21 additions & 6 deletions src/rulesets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,18 +74,33 @@ end
# getindex, getfield, __new__ #
###############################################################################

function rrule(::YotaRuleConfig, ::typeof(getproperty), s, f::Symbol)
y = getproperty(s, f)
function rrule(::YotaRuleConfig, ::typeof(getproperty), x::T, f::Symbol) where T
y = getproperty(x, f)
proj = ProjectTo(x)
# valT = Val(T) # perhaps more stable inside closure?
function getproperty_pullback(dy)
dy = unthunk(dy)
T = typeof(s)
nt = NamedTuple{(f,)}((dy,))
return NoTangent(), Tangent{T}(; nt...), ZeroTangent()
nt = NamedTuple{(f,)}((unthunk(dy),))
# not really sure whether this ought to unthunk or not, maybe ProjectTo will anyway, in which case best to be explicit?
return NoTangent(), proj(Tangent{T}(; nt...)), ZeroTangent()
end
return y, getproperty_pullback
end


# from https://github.com/FluxML/Optimisers.jl/pull/105#issuecomment-1229243707
function rrule(::typeof(getfield), x::T, f::Symbol) where T
y = getfield(x, f)
proj = ProjectTo(x)
# valT = Val(T) # perhaps more stable inside closure?
function getfield_pullback(dy)
nt = NamedTuple{(f,)}((unthunk(dy),))
# not really sure whether this ought to unthunk or not, maybe ProjectTo will anyway, in which case best to be explicit?
return NoTangent(), proj(Tangent{T}(; nt...)), ZeroTangent()
end
return y, getfield_pullback
end


function rrule(::YotaRuleConfig, ::typeof(getfield), s::Tuple, f::Int)
y = getfield(s, f)
function tuple_getfield_pullback(dy)
Expand Down
11 changes: 10 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,14 @@
is_kwfunc(f) = (name = string(f); endswith(name, "##kw") || endswith(name, "##kw\""))
is_kwfunc(v::Variable) = is_kwfunc(v._op.val)

function unkwfunc(f, args...)
@assert is_kwfunc(f) "Trying to undo Core.kwfunc() on f, but f is not a kw func"
nokw_f = args[2]
@assert Core.kwfunc(nokw_f) === f
return nokw_f
end


# REPL utils - unstable API! don't use in library code!
Base.:(^)(tape::Tape, i::Integer) = tape[V(i)].val
Base.:(//)(tape::Tape, i::Integer) = tape[V(i)].val
Base.:(:)(tape::Tape, i::Integer) = tape[V(i)].val
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using Test
using Yota
using Yota: gradtape, gradcheck, update_chainrules_primitives!
using Yota: trace, compile, play!
using Yota: make_rrule, YotaRuleConfig
using Yota: make_rrule!, YotaRuleConfig
import ChainRulesCore: Tangent, ZeroTangent, NoTangent, rrule_via_ad
import ChainRulesTestUtils: test_rrule

Expand Down
17 changes: 4 additions & 13 deletions test/test_cr_api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ rrule(::YotaRuleConfig, ::typeof(primitive_test2), x; y=1) = primitive_test2(x;
@testset "chainrules api" begin
config = YotaRuleConfig()

rr = make_rrule(double_inc, 2.0)
val, pb = rr(config, double_inc, 3.0)
make_rrule!(double_inc, 2.0)
val, pb = rrule(config, double_inc, 3.0)
@test val == 7
@test pb(1.0) == (ZeroTangent(), 2.0)

trace(broadcasted, double_dec, [1.0, 2.0])

rr = make_rrule(broadcasted, double_dec, [1.0, 2.0])
val, pb = rr(config, broadcasted, double_dec, [3.0, 4.0])
make_rrule!(broadcasted, double_dec, [1.0, 2.0])
val, pb = rrule(config, broadcasted, double_dec, [3.0, 4.0])
@test val == [5.0, 7.0]
@test pb([1, 1]) == (ZeroTangent(), ZeroTangent(), [2.0, 2.0])

Expand All @@ -56,13 +56,4 @@ rrule(::YotaRuleConfig, ::typeof(primitive_test2), x; y=1) = primitive_test2(x;
dxs = map(unthunk, pb([1, 2, 3]))
@test dxs == (ZeroTangent(), ZeroTangent(), [2.0, 4.0, 6.0])

# This context and corresponding isprimitive() are deprecated
# x, y = rand(2)
# @test isprimitive(CR_CTX, primitive_test, x) == true
# @test isprimitive(CR_CTX, Core.kwfunc(primitive_test), (y=1,), primitive_test, x) == true
# @test isprimitive(CR_CTX, primitive_test, x, y) == false

# @test isprimitive(CR_CTX, primitive_test2, x) == true
# @test isprimitive(CR_CTX, Core.kwfunc(primitive_test2), (y=1,), primitive_test, x,) == true
# @test isprimitive(CR_CTX, primitive_test2, x, y) == false
end
9 changes: 9 additions & 0 deletions test/test_grad.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
import Statistics
import Yota: isprimitive, ChainRulesCtx
import ChainRulesCore
import ChainRulesCore: rrule, Tangent, ZeroTangent, NoTangent, @opt_out


@testset "ChainRulesCtx" begin
@test isprimitive(ChainRulesCtx(), sum, rand(3, 4))
@test isprimitive(ChainRulesCtx(), Core.kwfunc(sum), (dims=1,), sum, rand(3, 4))
@test !isprimitive(ChainRulesCtx(), Core.kwfunc(sum), (dims=1,), sum, rand())
end


loss_simple(W, b, x) = sum(W * x .+ b)
loss_double_broadcast(W, b, x) = sum(sin.(W * x) .+ b)
loss_double_broadcast2(b, x) = sum(x .* x .+ b)
Expand Down