Skip to content

Commit b260a9f

Browse files
committed
Opt out the problematic StatsBase-like weighted mean(::AbstractVector, ::AbstractVector)
1 parent 1b89663 commit b260a9f

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

src/rulesets/Statistics/statistics.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,18 @@ function rrule(
3131
return y_sum / n, mean_pullback_f
3232
end
3333

34+
# Similar to https://github.com/JuliaDiff/ChainRules.jl/issues/522
35+
# The rule above assumes `f` is callable. Arrays are not, this came up when taking
36+
# the mean arrays with weights in StatsBase
37+
@opt_out ChainRulesCore.rrule(
38+
config::RuleConfig{>:HasReverseMode},
39+
::typeof(mean),
40+
x::AbstractArray,
41+
wt::AbstractArray{<:Union{Real,Complex,AbstractArray}};
42+
dims=:
43+
)
44+
45+
3446
#####
3547
##### variance
3648
#####

test/rulesets/Statistics/statistics.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,13 @@
2828
test_rrule(mean, abs, [-2.0 4.0; 5.0 1.9]; fkwargs=(;dims=2))
2929
test_rrule(mean, sqrt, rand(ComplexF64, 3, 4); fkwargs=(;dims=(1,)))
3030
end
31+
32+
@testset "Regression Test against StatsBase-like Weighted Mean" begin
33+
@eval struct DummyWeights <: AbstractVector{Float64} # DummyType that looks like StatsBase's Weights types
34+
end
35+
# This should return nothing as we have no rule for this. (we opted opt)
36+
@test nothing == rrule(ChainRulesTestUtils.TestConfig(), mean, [1.0, 2.0], DummyWeights())
37+
end
3138
end
3239

3340
@testset "variation: $var" for var in (std, var)

0 commit comments

Comments
 (0)