Skip to content

Commit

Permalink
Merge #780
Browse files Browse the repository at this point in the history
780: Move a bunch of no_grad to ChainRules r=oxinabox a=oxinabox

This is the partner to JuliaDiff/ChainRules.jl#252.
It will fail until that is merged and tagged.

What is left is:

- Types (because of JuliaDiff/ChainRulesCore.jl#213) (e.g. `Colon`, `OneTo`, `Channel`)
- Things for which the derivative is `Zero()`, not `DoesNotExist()` (e.g. `one`, `ones`, `zero`, `zeros`)
- Things that felt too magic: e.g. `Base.eval`


Should I bump patch version and tag a release?

Co-authored-by: Lyndon White <lyndon.white@invenialabs.co.uk>
Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
  • Loading branch information
3 people authored Sep 3, 2020
2 parents 4934bc2 + a2026e7 commit 40b9f1e
Show file tree
Hide file tree
Showing 9 changed files with 29 additions and 23 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.5.5"
version = "0.5.6"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand All @@ -27,7 +27,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
AbstractFFTs = "0.5"
ArrayLayouts = "0.1, 0.2, 0.3, 0.4"
ChainRules = "0.7.0"
ChainRules = "0.7.16"
DiffRules = "1.0"
FillArrays = "0.8, 0.9"
ForwardDiff = "0"
Expand Down
2 changes: 2 additions & 0 deletions src/compiler/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ Convert `x` from the differentials types ChainRules uses to the format Zygote us
"""
@inline wrap_chainrules_output(x) = unthunk(x) # For now we are just not going to deal with thunks
# Tuples are converted element by element.
@inline wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x)
# Zygote convention: when every partial is an AbstractZero (e.g. for a multi-input function),
# collapse the whole tuple into a single `nothing` rather than a tuple of `nothing`s.
# NOTE(review): the empty tuple `()` also matches `Tuple{Vararg{...}}`, so it would map to
# `nothing` as well — confirm that this is intended.
@inline wrap_chainrules_output(x::Tuple{Vararg{ChainRules.AbstractZero}}) = nothing
# A lone AbstractZero becomes `nothing`.
@inline wrap_chainrules_output(x::ChainRules.AbstractZero) = nothing
for T_outer in (:Tuple, :NamedTuple)
# we create separate methods rather than using a `Union` + an `if` so that we avoid a
Expand Down
9 changes: 1 addition & 8 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,7 @@ using Distributed: pmap
@adjoint Array(xs::AbstractArray) = Array(xs), ȳ -> (ȳ,)
@adjoint Array(xs::Array) = Array(xs), ȳ -> (ȳ,)

@nograd size, length, eachindex, Base.OneTo, axes, Colon(), findfirst, findlast, findall, ones, zeros, one, zero, any, all
@nograd randn, randexp, randn!, randexp!
@static if VERSION > v"1.3"
@nograd Random.default_rng
end

@adjoint Base.rand(rng::AbstractRNG, ::Type{T}, dims...) where {T<:Number} =
rand(rng, T, dims...), _ -> nothing
@nograd ones, zeros, Base.OneTo, Colon(), one, zero

# Adjoint of `Base.vect`: the pullback splats the incoming `Δ` into a tuple of partials.
@adjoint Base.vect(xs...) = Base.vect(xs...), Δ -> (Δ...,)

Expand Down
7 changes: 1 addition & 6 deletions src/lib/base.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
@nograd readline, Base.gc_num, Base.time_ns, Base.print, Base.println, Base.show,
Core.show, Core.print, Core.println, string, repr, Threads.nthreads, Threads.threadid

# Gradient of AD stacks

grad_mut(::AbstractVector) = []
Expand Down Expand Up @@ -47,11 +44,9 @@ end
end
end

@nograd haskey

# Channels

@nograd Channel, schedule
@nograd Channel

grad_mut(ch::Channel) = Channel(ch.sz_max)

Expand Down
2 changes: 0 additions & 2 deletions src/lib/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ using Base.Broadcast
using Base.Broadcast: Broadcasted, AbstractArrayStyle, broadcasted, materialize
using NNlib

@nograd Broadcast.combine_styles, Broadcast.result_style

# There's a saying that debugging code is about twice as hard as writing it in
# the first place. So if you're as clever as you can be when writing code, how
# will you ever debug it?
Expand Down
5 changes: 1 addition & 4 deletions src/lib/lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,7 @@ function accum(x::RefValue, y::RefValue)
end

# Core functions

@nograd Core.apply_type, Core.typeof, nfields, fieldtype, Core.TypeVar, Core.UnionAll,
(==), (===), (<=), (>=), (<), (>), isempty, supertype, Base.typename,
eps, Meta.parse, Base.eval, sleep, isassigned
@nograd eps, Base.eval, Core.TypeVar, Core.UnionAll

@adjoint deepcopy(x) = deepcopy(x), ȳ -> (ȳ,)

Expand Down
2 changes: 1 addition & 1 deletion src/lib/number.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

@nograd floor, ceil, trunc, round, hash, div
@nograd floor, ceil, trunc, round, div

@adjoint Base.literal_pow(::typeof(^), x::Number, ::Val{p}) where {p} =
Base.literal_pow(^,x,Val(p)),
Expand Down
16 changes: 16 additions & 0 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,22 @@ using Zygote, Test, ChainRules
@test mimo_pullback_hitcount[] == 1
end

@testset "all AbstractZero partials" begin
# While ChainRules always has a partial for every input, Zygote collapses them all
# into a single `nothing` if they are all zero-like.

# Example function whose result is looked up from a constant array and never uses `x`.
not_diff_eg(x, i) = [10, 20][i]
function ChainRules.rrule(::typeof(not_diff_eg), x, i)
# The pullback ignores Δ and returns one partial per input (plus the one for the function itself).
function not_diff_eg_pullback(Δ)
return ChainRules.NO_FIELDS, ChainRules.Zero(), ChainRules.DoesNotExist()
end
return not_diff_eg(x, i), not_diff_eg_pullback
end

_, pb = Zygote.pullback(not_diff_eg, 10.4, 2)
# The pullback result must be exactly `nothing`, not a tuple of `nothing`s.
@test pb(1.2) === nothing
end

@testset "nested AD hitting identity(::Tuple) pullback" begin
# This is a particularly fiddly case.
# It's kind of a simplified version of `sin'''(0.5)` but different in some places.
Expand Down
5 changes: 5 additions & 0 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1538,9 +1538,14 @@ end
end

@testset "@nograd" begin
@test gradient(x->eachindex([10,20,30])[1], 11) == (nothing,)

# These are defined in ChainRules; we test them here to check that we handle them correctly.
@test gradient(x -> findfirst(ismissing, x), [1, missing]) == (nothing,)
@test gradient(x -> findlast(ismissing, x), [1, missing]) == (nothing,)
@test gradient(x -> findall(ismissing, x)[1], [1, missing]) == (nothing,)


@test gradient(x -> Zygote.ignore(() -> x*x), 1) == (nothing,)
@test gradient(x -> Zygote.@ignore(x*x), 1) == (nothing,)
@test gradient(1) do x
Expand Down

0 comments on commit 40b9f1e

Please sign in to comment.