From ee0d9acf4f8affd9d72501ac89be1708e1d9baeb Mon Sep 17 00:00:00 2001 From: Nick Robinson Date: Sat, 21 Sep 2019 11:29:24 +0100 Subject: [PATCH 1/2] Rename `DNE` -> `DoesNotExist` --- src/rulesets/Base/array.jl | 8 ++++---- src/rulesets/Base/base.jl | 2 +- src/rulesets/Base/broadcast.jl | 2 +- src/rulesets/Base/mapreduce.jl | 8 ++++---- src/rulesets/LinearAlgebra/blas.jl | 12 ++++++------ src/rulesets/LinearAlgebra/factorization.jl | 4 ++-- src/rulesets/Statistics/statistics.jl | 2 +- test/rulesets/Base/array.jl | 10 +++++----- test/rulesets/Base/broadcast.jl | 2 +- test/rulesets/LinearAlgebra/factorization.jl | 6 +++--- test/runtests.jl | 2 +- test/test_util.jl | 8 ++++---- 12 files changed, 33 insertions(+), 33 deletions(-) diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index fa69b17c7..5cb78318e 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -4,7 +4,7 @@ function rrule(::typeof(reshape), A::AbstractArray, dims::Tuple{Vararg{Int}}) function reshape_pullback(Ȳ) - return (NO_FIELDS, @thunk(reshape(Ȳ, dims)), DNE()) + return (NO_FIELDS, @thunk(reshape(Ȳ, dims)), DoesNotExist()) end return reshape(A, dims), reshape_pullback end @@ -12,7 +12,7 @@ end function rrule(::typeof(reshape), A::AbstractArray, dims::Int...) function reshape_pullback(Ȳ) ∂A = @thunk(reshape(Ȳ, dims)) - return (NO_FIELDS, ∂A, fill(DNE(), length(dims))...) + return (NO_FIELDS, ∂A, fill(DoesNotExist(), length(dims))...) end return reshape(A, dims...), reshape_pullback end @@ -63,14 +63,14 @@ end function rrule(::typeof(fill), value::Any, dims::Tuple{Vararg{Int}}) function fill_pullback(Ȳ) - return (NO_FIELDS, @thunk(sum(Ȳ)), DNE()) + return (NO_FIELDS, @thunk(sum(Ȳ)), DoesNotExist()) end return fill(value, dims), fill_pullback end function rrule(::typeof(fill), value::Any, dims::Int...) function fill_pullback(Ȳ) - return (NO_FIELDS, @thunk(sum(Ȳ)), ntuple(_->DNE(), length(dims))...) + return (NO_FIELDS, @thunk(sum(Ȳ)), ntuple(_->DoesNotExist(), length(dims))...) end return fill(value, dims), fill_pullback end diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 5ff81a21a..e3c446a59 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -58,7 +58,7 @@ @scalar_rule(abs(x::Complex), Wirtinger(x' / 2Ω, x / 2Ω)) @scalar_rule(hypot(x::Real), sign(x)) @scalar_rule(hypot(x::Complex), Wirtinger(x' / 2Ω, x / 2Ω)) -@scalar_rule(rem2pi(x, r::RoundingMode), (One(), DNE())) +@scalar_rule(rem2pi(x, r::RoundingMode), (One(), DoesNotExist())) @scalar_rule(+(x), One()) @scalar_rule(-(x), -1) diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl index 694d1c2d3..2807d63a4 100644 --- a/src/rulesets/Base/broadcast.jl +++ b/src/rulesets/Base/broadcast.jl @@ -24,7 +24,7 @@ end function rrule(::typeof(broadcast), f, x) values, derivs = _cast_diff(f, x) function broadcast_pullback(ΔΩ) - return (NO_FIELDS, DNE(), @thunk(ΔΩ .* derivs)) + return (NO_FIELDS, DoesNotExist(), @thunk(ΔΩ .* derivs)) end return values, broadcast_pullback end diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 75a7a887b..298b807d3 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -7,7 +7,7 @@ function rrule(::typeof(map), f, xs...) function map_pullback(ȳ) ntuple(length(xs)+2) do full_i full_i == 1 && return NO_FIELDS - full_i == 2 && return DNE() + full_i == 2 && return DoesNotExist() i = full_i-2 @thunk map(ȳ, xs...) do ȳi, xis... _, pullback = _checked_rrule(f, xis...) @@ -39,7 +39,7 @@ for mf in (:mapreduce, :mapfoldl, :mapfoldr) _, ∂xi = pullback_f(ȳi) extern(∂xi) end - (NO_FIELDS, DNE(), DNE(), ∂x) + (NO_FIELDS, DoesNotExist(), DoesNotExist(), ∂x) end return y, $pullback_name end @@ -67,7 +67,7 @@ end function rrule(::typeof(sum), f, x::AbstractArray{<:Real}; dims=:) y, mr_pullback = rrule(mapreduce, f, Base.add_sum, x; dims=dims) function sum_pullback(ȳ) - return NO_FIELDS, DNE(), last(mr_pullback(ȳ)) + return NO_FIELDS, DoesNotExist(), last(mr_pullback(ȳ)) end return y, sum_pullback end @@ -83,7 +83,7 @@ end function rrule(::typeof(sum), ::typeof(abs2), x::AbstractArray{<:Real}; dims=:) y = sum(abs2, x; dims=dims) function sum_abs2_pullback(ȳ) - return (NO_FIELDS, DNE(), @thunk(2ȳ .* x)) + return (NO_FIELDS, DoesNotExist(), @thunk(2ȳ .* x)) end return y, sum_abs2_pullback end diff --git a/src/rulesets/LinearAlgebra/blas.jl b/src/rulesets/LinearAlgebra/blas.jl index 5772d2832..261cbe661 100644 --- a/src/rulesets/LinearAlgebra/blas.jl +++ b/src/rulesets/LinearAlgebra/blas.jl @@ -26,7 +26,7 @@ function rrule(::typeof(BLAS.dot), n, X, incx, Y, incy) ∂X = @thunk scal!(n, ΔΩ, blascopy!(n, Y, incy, _zeros(X), incx), incx) ∂Y = @thunk scal!(n, ΔΩ, blascopy!(n, X, incx, _zeros(Y), incy), incy) end - return (NO_FIELDS, DNE(), ∂X, DNE(), ∂Y, DNE()) + return (NO_FIELDS, DoesNotExist(), ∂X, DoesNotExist(), ∂Y, DoesNotExist()) end return Ω, blas_dot_pullback end @@ -60,7 +60,7 @@ function rrule(::typeof(BLAS.nrm2), n, X, incx) ΔΩ = extern(ΔΩ) ∂X = scal!(n, ΔΩ / Ω, blascopy!(n, X, incx, _zeros(X), incx), incx) end - return (NO_FIELDS, DNE(), ∂X, DNE()) + return (NO_FIELDS, DoesNotExist(), ∂X, DoesNotExist()) end return Ω, nrm2_pullback @@ -92,13 +92,13 @@ function rrule(::typeof(BLAS.asum), n, X, incx) else ΔΩ = extern(ΔΩ) ∂X = @thunk scal!( - n, + n, ΔΩ, blascopy!(n, sign.(X), incx, _zeros(X), incx), incx ) end - return (NO_FIELDS, DNE(), ∂X, DNE()) + return (NO_FIELDS, DoesNotExist(), ∂X, DoesNotExist()) end return Ω, asum_pullback end @@ -130,7 +130,7 @@ function rrule(::typeof(gemv), tA::Char, α::T, A::AbstractMatrix{T}, x̄ -> gemv!('N', α, A, ȳ, one(T), x̄) ) end - return (NO_FIELDS, DNE(), @thunk(dot(ȳ, y) / α), ∂A, ∂x) + return (NO_FIELDS, DoesNotExist(), @thunk(dot(ȳ, y) / α), ∂A, ∂x) end return y, gemv_pullback end @@ -195,7 +195,7 @@ function rrule(::typeof(gemm), tA::Char, tB::Char, α::T, ) end end - return (NO_FIELDS, DNE(), DNE(), @thunk(dot(C̄, C) / α), ∂A, ∂B) + return (NO_FIELDS, DoesNotExist(), DoesNotExist(), @thunk(dot(C̄, C) / α), ∂A, ∂B) end return C, gemv_pullback end diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 6fc69ebee..2cdd97827 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -29,7 +29,7 @@ function rrule(::typeof(getproperty), F::SVD, x::Symbol) update = (X̄::NamedTuple{(:U,:S,:V)}) -> _update!(X̄, ∂, x) ∂F = InplaceableThunk(∂, update) - return NO_FIELDS, ∂F, DNE() + return NO_FIELDS, ∂F, DoesNotExist() end return getproperty(F, x), getproperty_svd_pullback end @@ -93,7 +93,7 @@ function rrule(::typeof(getproperty), F::Cholesky, x::Symbol) ∂F = @thunk UpperTriangular(Ȳ') end end - return NO_FIELDS, ∂F, DNE() + return NO_FIELDS, ∂F, DoesNotExist() end return getproperty(F, x), getproperty_cholesky_pullback end diff --git a/src/rulesets/Statistics/statistics.jl b/src/rulesets/Statistics/statistics.jl index 0c40fa36b..b267b1e73 100644 --- a/src/rulesets/Statistics/statistics.jl +++ b/src/rulesets/Statistics/statistics.jl @@ -29,7 +29,7 @@ function rrule(::typeof(mean), f, x::AbstractArray{<:Real}) _, _, ∂sum_x = sum_pullback(ȳ) extern(∂sum_x) / n end - return (NO_FIELDS, DNE(), ∂x) + return (NO_FIELDS, DoesNotExist(), ∂x) end return y_sum / n, mean_pullback end diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index c113b6698..ab1089f8f 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -7,7 +7,7 @@ (s̄, Ā, d̄) = pullback(Ȳ) @test s̄ == NO_FIELDS - @test d̄ isa DNE + @test d̄ isa DoesNotExist @test extern(Ā) == reshape(Ȳ, (5, 4)) B, pullback = rrule(reshape, A, 5, 4) @@ -16,8 +16,8 @@ Ȳ = randn(rng, 4, 5) (s̄, Ā, d̄1, d̄2) = pullback(Ȳ) @test s̄ == NO_FIELDS - @test d̄1 isa DNE - @test d̄2 isa DNE + @test d̄1 isa DoesNotExist + @test d̄2 isa DoesNotExist @test extern(Ā) == reshape(Ȳ, 5, 4) end @@ -56,13 +56,13 @@ end @test y == [44, 44, 44, 44] (ds, dv, dd) = pullback(ones(4)) @test ds === NO_FIELDS - @test dd isa DNE + @test dd isa DoesNotExist @test extern(dv) == 4 y, pullback = rrule(fill, 2.0, (3, 3, 3)) @test y == fill(2.0, (3, 3, 3)) (ds, dv, dd) = pullback(ones(3, 3, 3)) @test ds === NO_FIELDS - @test dd isa DNE + @test dd isa DoesNotExist @test dv ≈ 27.0 end diff --git a/test/rulesets/Base/broadcast.jl b/test/rulesets/Base/broadcast.jl index 9dc592008..17f2663ab 100644 --- a/test/rulesets/Base/broadcast.jl +++ b/test/rulesets/Base/broadcast.jl @@ -6,7 +6,7 @@ @test y == sin.(x) (dself, dsin, dx) = pullback(One()) @test dself == NO_FIELDS - @test dsin == DNE() + @test dsin == DoesNotExist() @test extern(dx) == cos.(x) x̄, ȳ = rand(), rand() diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index 0b29cec96..57d873f7e 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -12,7 +12,7 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo dself1, dF, dp = dF_pullback(Ȳ) @test dself1 === NO_FIELDS - @test dp === DNE() + @test dp === DoesNotExist() ΔF = extern(dF) dself2, dX = dX_pullback(ΔF) @@ -37,7 +37,7 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo Ȳ = ones(size(Y)...) (dself, dF, dp) = dF_pullback(Ȳ) @test dself === NO_FIELDS - @test dp === DNE() + @test dp === DoesNotExist() ChainRules.accumulate!(X̄, dF) end @test X̄.U ≈ ones(3, 2) atol=1e-6 @@ -64,7 +64,7 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo Ȳ = (p === :U ? UpperTriangular : LowerTriangular)(randn(rng, size(Y))) (dself, dF, dp) = dF_pullback(Ȳ) @test dself === NO_FIELDS - @test dp === DNE() + @test dp === DoesNotExist() # NOTE: We're doing Nabla-style testing here and avoiding using the `j′vp` # machinery from FiniteDifferences because that isn't set up to respect diff --git a/test/runtests.jl b/test/runtests.jl index 050110cba..aa471c266 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -12,7 +12,7 @@ using Test # For testing purposes we use a lot of using ChainRulesCore: extern, accumulate, accumulate!, store!, @scalar_rule, Wirtinger, wirtinger_primal, wirtinger_conjugate, - Zero, One, DNE, Thunk, AbstractDifferential + Zero, One, DoesNotExist, Thunk, AbstractDifferential Random.seed!(1) # Set seed that all testsets should reset to. diff --git a/test/test_util.jl b/test/test_util.jl index 6020d68ff..b41b88964 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -189,8 +189,8 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm x̄s_fd = _make_fdm_call(fdm, f, ȳ, xs, x̄s .== nothing) for (x̄_ad, x̄_fd) in zip(x̄s_ad, x̄s_fd) if x̄_fd === nothing - # The way we've structured the above, this tests that the rule is a DNERule - @test x̄_ad isa DNE + # The way we've structured the above, this tests that the rule is a DoesNotExistRule + @test x̄_ad isa DoesNotExist else @test isapprox(x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...) end @@ -208,8 +208,8 @@ function Base.isapprox(ad::Wirtinger, fd; kwargs...) error("Finite differencing with Wirtinger rules not implemented") end -function Base.isapprox(d_ad::DNE, d_fd; kwargs...) - error("Tried to differentiate w.r.t. a DNE") +function Base.isapprox(d_ad::DoesNotExist, d_fd; kwargs...) + error("Tried to differentiate w.r.t. a `DoesNotExist`") end function Base.isapprox(d_ad::AbstractDifferential, d_fd; kwargs...) From 3097e66093189e3b7b920f1bed1e4efc743bf5fa Mon Sep 17 00:00:00 2001 From: Nick Robinson Date: Fri, 25 Oct 2019 19:00:02 +0100 Subject: [PATCH 2/2] Bump lowerbound on ChainRulesCore --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index a557d44e0..5a4a85776 100644 --- a/Project.toml +++ b/Project.toml @@ -10,7 +10,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] -ChainRulesCore = "0.3, 0.4" +ChainRulesCore = "0.4" FiniteDifferences = "^0.7" julia = "^1.0"