Skip to content

Commit

Permalink
Merge pull request #96 from dfdx/no-rrule-piracy
Browse files Browse the repository at this point in the history
Remove generic rrule for broadcasted
  • Loading branch information
dfdx authored Jul 10, 2021
2 parents 1fa7d08 + d57576d commit aee7ea3
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 71 deletions.
29 changes: 1 addition & 28 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,31 +20,4 @@ function update_chainrules_primitives!()
end


is_chainrules_primitive(sig) = sig in CHAIN_RULE_PRIMITIVES[]


################################################################

# from Zygote:
# https://github.com/FluxML/Zygote.jl/blob/d5be4d5ca80e79278d714eaac15ca71904a262e3/src/lib/array.jl#L177-L185
struct StaticGetter{i} end
(::StaticGetter{i})(v) where {i} = v[i]

@generated function _unzip(tuples, ::Val{N}) where {N}
Expr(:tuple, (:(map($(StaticGetter{i}()), tuples)) for i 1:N)...)
end

function unzip(tuples)
N = length(first(tuples))
_unzip(tuples, Val(N))
end


function rrule(::typeof(Broadcast.broadcasted), f::F, args...) where F
ys, pbs = unzip(rrule.(f, args...))
function pullback(Δ)
dxs = map((pb, Δ) -> pb(Δ), pbs, Δ)
return NoTangent(), unzip(dxs)...
end
return ys, pullback
end
is_chainrules_primitive(sig) = sig in CHAIN_RULE_PRIMITIVES[]
2 changes: 1 addition & 1 deletion src/grad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ function step_back!(tape::Tape, y::Variable, deriv_todo::Vector{Variable})
else
sig_str = join(["::$T" for T in Ghost.call_signature(tape, tape[y]).parameters], ", ")
error("No deriative rule found for op $(tape[y]), " *
"try defining it using ChainRules.rrule($sig_str) = ...")
"try defining it using ChainRulesCore.rrule($sig_str) = ...")
end
for (i, x) in enumerate(y_fargs)
if x isa V
Expand Down
43 changes: 1 addition & 42 deletions src/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,41 +6,6 @@ function ngradient(f, args...)
end


# # from https://github.com/FluxML/Zygote.jl/blob/master/test/gradcheck.jl

# function ngradient(f, xs::AbstractArray...)
# grads = zero.(xs)
# for (x, Δ) in zip(xs, grads), i in 1:length(x)
# δ = sqrt(eps())
# tmp = x[i]
# x[i] = tmp - δ/2
# y1 = f(xs...)
# x[i] = tmp + δ/2
# y2 = f(xs...)
# x[i] = tmp
# Δ[i] = (y2-y1)/δ
# end
# return grads
# end


# function ngradient2(f, xs, n)
# x = xs[n]
# Δ = zero(x)
# for i in 1:length(x)
# δ = sqrt(eps())
# tmp = x[i]
# x[i] = tmp - δ/2
# y1 = f(xs...)
# x[i] = tmp + δ/2
# y2 = f(xs...)
# x[i] = tmp
# Δ[i] = (y2-y1)/δ
# end
# return Δ
# end


function gradcheck(f, args...)
y_grads = grad(f, args...)[2]
# don't check gradient w.r.t. function since ngradient can't do it
Expand All @@ -51,10 +16,4 @@ function gradcheck(f, args...)
push!(results, isapprox(y_grads[n], n_grad[n], rtol = 1e-5, atol = 1e-5))
end
return all(results)
end


# gradcheck = gradcheck2

# # gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
# # gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
end
35 changes: 35 additions & 0 deletions test/test_examples.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,40 @@
using Statistics
using Random
import ChainRulesCore: rrule, NoTangent


################################################################

# TODO: migrate to a common broadcasting rrule

# from Zygote:
# https://github.com/FluxML/Zygote.jl/blob/d5be4d5ca80e79278d714eaac15ca71904a262e3/src/lib/array.jl#L177-L185
struct StaticGetter{i} end
(::StaticGetter{i})(v) where {i} = v[i]

@generated function _unzip(tuples, ::Val{N}) where {N}
Expr(:tuple, (:(map($(StaticGetter{i}()), tuples)) for i 1:N)...)
end

function unzip(tuples)
N = length(first(tuples))
_unzip(tuples, Val(N))
end


function rrule(::typeof(Broadcast.broadcasted), f::F, args...) where F
ys, pbs = unzip(rrule.(f, args...))
function pullback(Δ)
dxs = map((pb, Δ) -> pb(Δ), pbs, Δ)
return NoTangent(), unzip(dxs)...
end
return ys, pullback
end

Yota.update_chainrules_primitives!()

################################################################



obj(Y, X, b) = mean((Y .- X * b) .^ 2.0) # objective to minimize
Expand Down

0 comments on commit aee7ea3

Please sign in to comment.