From 8424476fda14585decc476c163afea8e6666b8a7 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 9 Dec 2022 14:09:34 +0000 Subject: [PATCH 1/4] revert revert #615 (i.e. revert #619) --- src/rulesets/Statistics/statistics.jl | 24 +++++++++++++++++------ test/rulesets/Statistics/statistics.jl | 24 ++++++++++++++++++++++- test/test_helpers.jl | 27 ++++++++++++++++++++++++++ 3 files changed, 68 insertions(+), 7 deletions(-) diff --git a/src/rulesets/Statistics/statistics.jl b/src/rulesets/Statistics/statistics.jl index cf4707a2d..07e0d62d6 100644 --- a/src/rulesets/Statistics/statistics.jl +++ b/src/rulesets/Statistics/statistics.jl @@ -6,19 +6,31 @@ _denom(x, dims) = size(x, dims) _denom(x, dims::Colon) = length(x) _denom(x, dims::Union{Tuple, AbstractArray}) = mapreduce(i->size(x, i), Base.mul_prod, unique(dims), init=1) -# TODO: We have `mean(f, x; dims)` as of 1.3.0-DEV.36 -# https://github.com/JuliaDiff/ChainRules.jl/issues/85 -function rrule(::typeof(mean), x::AbstractArray{<:Real}; dims=:) - y_sum, sum_pullback = rrule(sum, x; dims=dims) +function rrule(::typeof(mean), x::AbstractArray{<:Union{Real,Complex,AbstractArray}}; dims=:) + y_sum, sum_pullback = rrule(sum, x; dims) n = _denom(x, dims) function mean_pullback(ȳ) - _, ∂sum_x = sum_pullback(ȳ) - ∂x = unthunk(∂sum_x) / n + _, ∂x = sum_pullback(unthunk(ȳ) / n) return (NoTangent(), ∂x) end return y_sum / n, mean_pullback end +function rrule( + config::RuleConfig{>:HasReverseMode}, + ::typeof(mean), + f::F, + x::AbstractArray{T}; + dims=:, +) where {F, T<:Union{Real,Complex,AbstractArray}} + y_sum, sum_pullback = rrule(config, sum, f, x; dims) + n = _denom(x, dims) + function mean_pullback_f(ȳ) + return sum_pullback(unthunk(ȳ) / n) + end + return y_sum / n, mean_pullback_f +end + ##### ##### variance ##### diff --git a/test/rulesets/Statistics/statistics.jl b/test/rulesets/Statistics/statistics.jl index 9ca46e13a..3e0b36a82 100644 --- a/test/rulesets/Statistics/statistics.jl +++ 
b/test/rulesets/Statistics/statistics.jl @@ -1,10 +1,32 @@ @testset "mean" begin - @testset "Basic" begin + @testset "mean(x)" begin @gpu test_rrule(mean, randn(9)) + test_rrule(mean, randn(ComplexF64,2,4)) + test_rrule(mean, transpose(rand(3))) + test_rrule(mean, [rand(3) for _ in 1:4]; check_inferred=false) end @testset "with dims kwargs" begin @gpu test_rrule(mean, randn(9); fkwargs=(;dims=1)) @gpu test_rrule(mean, randn(9,4); fkwargs=(;dims=2)) + @gpu test_rrule(mean, [rand(2) for _ in 1:3, _ in 1:4]; fkwargs=(;dims=2), check_inferred=false) + end + @testset "mean(f, x)" begin + # This shares its implementation with sum(f, x). Similar tests should cover all cases: + test_rrule(mean, abs, [-4.0, 2.0, 2.0]) + test_rrule(mean, log, rand(3, 4) .+ 1) + test_rrule(mean, cbrt, randn(5)) + test_rrule(mean, Multiplier(2.0), [2.0, 4.0, 8.0]) # defined in test_helpers.jl + test_rrule(mean, Divider(1 + rand()), randn(5)) + + test_rrule(mean, sum, [[2.0, 4.0], [4.0,1.9]]; check_inferred=false) + + test_rrule(mean, log, rand(ComplexF64, 5)) + test_rrule(mean, sqrt, rand(ComplexF64, 5)) + test_rrule(mean, abs, rand(ComplexF64, 3, 4)) + + test_rrule(mean, abs, [-2.0 4.0; 5.0 1.9]; fkwargs=(;dims=1)) + test_rrule(mean, abs, [-2.0 4.0; 5.0 1.9]; fkwargs=(;dims=2)) + test_rrule(mean, sqrt, rand(ComplexF64, 3, 4); fkwargs=(;dims=(1,))) end end diff --git a/test/test_helpers.jl b/test/test_helpers.jl index 1f456931b..1aa85d52c 100644 --- a/test/test_helpers.jl +++ b/test/test_helpers.jl @@ -137,6 +137,28 @@ function ChainRulesCore.rrule(m::Multiplier, y, z) return m(y, z), Multiplier_pullback_3 end +""" + Divider(x) + +Stores a fixed `x` and divides by it, then squares the result. + +Especially for testing the gradient of higher order functions with respect to `x`. 
+``` +julia> map(Divider(2), [1 2 3 4 10]) +1×5 Matrix{Float64}: + 0.25 1.0 2.25 4.0 25.0 +``` +""" +struct Divider{T<:Real} + x::T +end +(d::Divider)(y::Real) = (y / d.x)^2 + +function ChainRulesCore.rrule(d::Divider, y::Real) + Divider_pullback(dΩ) = (Tangent{typeof(d)}(; x = -2 * dΩ * y^2 / d.x^3), 2 * dΩ * y / d.x^2) + return d(y), Divider_pullback +end + """ Counter() @@ -198,6 +220,11 @@ ChainRulesCore.frule((_, Δx), ::typeof(flog), x::Number) = log(x), inv(x) * Δx test_rrule(Multiplier(1.0 + 2im), 3.0 + 4im, 5.0 - 6im) test_rrule(Multiplier(rand(2,3)), rand(3,4), rand(4,5)) end + + @testset "Divider" begin + test_rrule(Divider(2.3), 4.5) + test_rrule(Divider(0.2), -3.4) + end @testset "Counter" begin c = Counter() From 1b896632c4a008a675912cf06ff8b366df1ed5c7 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 9 Dec 2022 15:57:10 +0000 Subject: [PATCH 2/4] remove nongpu compatble gpu test --- test/rulesets/Statistics/statistics.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/rulesets/Statistics/statistics.jl b/test/rulesets/Statistics/statistics.jl index 3e0b36a82..37d3a6adc 100644 --- a/test/rulesets/Statistics/statistics.jl +++ b/test/rulesets/Statistics/statistics.jl @@ -8,7 +8,7 @@ @testset "with dims kwargs" begin @gpu test_rrule(mean, randn(9); fkwargs=(;dims=1)) @gpu test_rrule(mean, randn(9,4); fkwargs=(;dims=2)) - @gpu test_rrule(mean, [rand(2) for _ in 1:3, _ in 1:4]; fkwargs=(;dims=2), check_inferred=false) + test_rrule(mean, [rand(2) for _ in 1:3, _ in 1:4]; fkwargs=(;dims=2), check_inferred=false) end @testset "mean(f, x)" begin # This shares its implementation with sum(f, x). 
Similar tests should cover all cases: From b260a9f6657d62c95d8621bcd83510d47ba3931a Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 9 Dec 2022 19:39:41 +0000 Subject: [PATCH 3/4] Opt out the problematic StatsBase-like weighted mean(::AbstractVector, ::AbstractVector) --- src/rulesets/Statistics/statistics.jl | 12 ++++++++++++ test/rulesets/Statistics/statistics.jl | 7 +++++++ 2 files changed, 19 insertions(+) diff --git a/src/rulesets/Statistics/statistics.jl b/src/rulesets/Statistics/statistics.jl index 07e0d62d6..08be133fd 100644 --- a/src/rulesets/Statistics/statistics.jl +++ b/src/rulesets/Statistics/statistics.jl @@ -31,6 +31,18 @@ function rrule( return y_sum / n, mean_pullback_f end +# Similar to https://github.com/JuliaDiff/ChainRules.jl/issues/522 +# The rule above assumes `f` is callable. Arrays are not; this came up when taking +# the mean of arrays with weights in StatsBase +@opt_out ChainRulesCore.rrule( + config::RuleConfig{>:HasReverseMode}, + ::typeof(mean), + x::AbstractArray, + wt::AbstractArray{<:Union{Real,Complex,AbstractArray}}; + dims=: +) + + ##### ##### variance ##### diff --git a/test/rulesets/Statistics/statistics.jl b/test/rulesets/Statistics/statistics.jl index 37d3a6adc..1b41fac0b 100644 --- a/test/rulesets/Statistics/statistics.jl +++ b/test/rulesets/Statistics/statistics.jl @@ -28,6 +28,13 @@ test_rrule(mean, abs, [-2.0 4.0; 5.0 1.9]; fkwargs=(;dims=2)) test_rrule(mean, sqrt, rand(ComplexF64, 3, 4); fkwargs=(;dims=(1,))) end + + @testset "Regression Test against StatsBase-like Weighted Mean" begin + @eval struct DummyWeights <: AbstractVector{Float64} # DummyType that looks like StatsBase's Weights types + end + # This should return nothing as we have no rule for this. 
(we opted out) + @test nothing == rrule(ChainRulesTestUtils.TestConfig(), mean, [1.0, 2.0], DummyWeights()) + end end @testset "variation: $var" for var in (std, var) From f767c34f45ae7aaa8b18c21dbd96031eb9f2f85b Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 9 Dec 2022 19:40:43 +0000 Subject: [PATCH 4/4] bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c6f8c9543..51337c18f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.45.0" +version = "1.46.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"