From d57576dc0dde8027eaaa383087cb88072698f48b Mon Sep 17 00:00:00 2001 From: Andrei Zhabinski Date: Sat, 10 Jul 2021 11:57:07 +0300 Subject: [PATCH] Remove generic rrule for broadcasted --- src/chainrules.jl | 29 +---------------------------- src/grad.jl | 2 +- src/gradcheck.jl | 43 +------------------------------------------ test/test_examples.jl | 35 +++++++++++++++++++++++++++++++++++ 4 files changed, 38 insertions(+), 71 deletions(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 7d59091..c724582 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -20,31 +20,4 @@ function update_chainrules_primitives!() end -is_chainrules_primitive(sig) = sig in CHAIN_RULE_PRIMITIVES[] - - -################################################################ - -# from Zygote: -# https://github.com/FluxML/Zygote.jl/blob/d5be4d5ca80e79278d714eaac15ca71904a262e3/src/lib/array.jl#L177-L185 -struct StaticGetter{i} end -(::StaticGetter{i})(v) where {i} = v[i] - -@generated function _unzip(tuples, ::Val{N}) where {N} - Expr(:tuple, (:(map($(StaticGetter{i}()), tuples)) for i ∈ 1:N)...) -end - -function unzip(tuples) - N = length(first(tuples)) - _unzip(tuples, Val(N)) -end - - -function rrule(::typeof(Broadcast.broadcasted), f::F, args...) where F - ys, pbs = unzip(rrule.(f, args...)) - function pullback(Δ) - dxs = map((pb, Δ) -> pb(Δ), pbs, Δ) - return NoTangent(), unzip(dxs)... - end - return ys, pullback -end +is_chainrules_primitive(sig) = sig in CHAIN_RULE_PRIMITIVES[] \ No newline at end of file diff --git a/src/grad.jl b/src/grad.jl index 6ff7924..86051c7 100644 --- a/src/grad.jl +++ b/src/grad.jl @@ -126,7 +126,7 @@ function step_back!(tape::Tape, y::Variable, deriv_todo::Vector{Variable}) else sig_str = join(["::$T" for T in Ghost.call_signature(tape, tape[y]).parameters], ", ") error("No deriative rule found for op $(tape[y]), " * - "try defining it using ChainRules.rrule($sig_str) = ...") + "try defining it using ChainRulesCore.rrule($sig_str) = ...") end for (i, x) in enumerate(y_fargs) if x isa V diff --git a/src/gradcheck.jl b/src/gradcheck.jl index 0151ee6..fc8297e 100644 --- a/src/gradcheck.jl +++ b/src/gradcheck.jl @@ -6,41 +6,6 @@ function ngradient(f, args...) end -# # from https://github.com/FluxML/Zygote.jl/blob/master/test/gradcheck.jl - -# function ngradient(f, xs::AbstractArray...) -# grads = zero.(xs) -# for (x, Δ) in zip(xs, grads), i in 1:length(x) -# δ = sqrt(eps()) -# tmp = x[i] -# x[i] = tmp - δ/2 -# y1 = f(xs...) -# x[i] = tmp + δ/2 -# y2 = f(xs...) -# x[i] = tmp -# Δ[i] = (y2-y1)/δ -# end -# return grads -# end - - -# function ngradient2(f, xs, n) -# x = xs[n] -# Δ = zero(x) -# for i in 1:length(x) -# δ = sqrt(eps()) -# tmp = x[i] -# x[i] = tmp - δ/2 -# y1 = f(xs...) -# x[i] = tmp + δ/2 -# y2 = f(xs...) -# x[i] = tmp -# Δ[i] = (y2-y1)/δ -# end -# return Δ -# end - - function gradcheck(f, args...) y_grads = grad(f, args...)[2] # don't check gradient w.r.t. function since ngradient can't do it @@ -51,10 +16,4 @@ function gradcheck(f, args...) push!(results, isapprox(y_grads[n], n_grad[n], rtol = 1e-5, atol = 1e-5)) end return all(results) -end - - -# gradcheck = gradcheck2 - -# # gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...) -# # gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...) +end \ No newline at end of file diff --git a/test/test_examples.jl b/test/test_examples.jl index e3c8dfc..d35e190 100644 --- a/test/test_examples.jl +++ b/test/test_examples.jl @@ -1,5 +1,40 @@ using Statistics using Random +import ChainRulesCore: rrule, NoTangent + + +################################################################ + +# TODO: migrate to a common broadcasting rrule + +# from Zygote: +# https://github.com/FluxML/Zygote.jl/blob/d5be4d5ca80e79278d714eaac15ca71904a262e3/src/lib/array.jl#L177-L185 +struct StaticGetter{i} end +(::StaticGetter{i})(v) where {i} = v[i] + +@generated function _unzip(tuples, ::Val{N}) where {N} + Expr(:tuple, (:(map($(StaticGetter{i}()), tuples)) for i ∈ 1:N)...) +end + +function unzip(tuples) + N = length(first(tuples)) + _unzip(tuples, Val(N)) +end + + +function rrule(::typeof(Broadcast.broadcasted), f::F, args...) where F + ys, pbs = unzip(rrule.(f, args...)) + function pullback(Δ) + dxs = map((pb, Δ) -> pb(Δ), pbs, Δ) + return NoTangent(), unzip(dxs)... + end + return ys, pullback +end + +Yota.update_chainrules_primitives!() + +################################################################ + obj(Y, X, b) = mean((Y .- X * b) .^ 2.0) # objective to minimize