Skip to content

Commit d5c644f

Browse files
committed
make Divider<:Function and use just that to test mean
1 parent 2566e5f commit d5c644f

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

test/rulesets/Statistics/statistics.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,22 @@
1515
test_rrule(mean, abs, [-4.0, 2.0, 2.0])
1616
test_rrule(mean, log, rand(3, 4) .+ 1)
1717
test_rrule(mean, cbrt, randn(5))
18-
multiplier = Multiplier(2.0)
19-
test_rrule(mean, x->multiplier(x), [2.0, 4.0, 8.0]) # defined in test_helpers.jl
20-
divider = Divider(1 + rand())
21-
test_rrule(mean, x->divider(x), randn(5))
18+
19+
test_rrule(mean, Divider(1 + rand()), randn(5)) # defined in test_helpers.jl
2220

2321
test_rrule(mean, sum, [[2.0, 4.0], [4.0,1.9]]; check_inferred=false)
2422

2523
test_rrule(mean, log, rand(ComplexF64, 5))
2624
test_rrule(mean, sqrt, rand(ComplexF64, 5))
27-
test_rrule(mean, abs, rand(ComplexF64, 3, 4))
25+
test_rrule(mean, abs, rand(ComplexF164, 3, 4))
2826

2927
test_rrule(mean, abs, [-2.0 4.0; 5.0 1.9]; fkwargs=(;dims=1))
3028
test_rrule(mean, abs, [-2.0 4.0; 5.0 1.9]; fkwargs=(;dims=2))
3129
test_rrule(mean, sqrt, rand(ComplexF64, 3, 4); fkwargs=(;dims=(1,)))
3230
end
3331

3432
@testset "Regression Test against StatsBase-like Weighted Mean" begin
35-
struct DummyWeights <: AbstractVector{Float64} # DummyType that looks like StatsBase's Weights types
33+
@eval struct DummyWeights <: AbstractVector{Float64} # DummyType that looks like StatsBase's Weights types
3634
end
3735
# This should hit the fallback "nothing" rule indicating no rule was defined
3836
@test nothing == rrule(ChainRulesTestUtils.TestConfig(), mean, [1.0, 2.0], DummyWeights())

test/test_helpers.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ function ChainRulesCore.rrule(m::Multiplier, y, z)
138138
end
139139

140140
"""
141-
Divider(x)
141+
Divider(x) <: Function
142142
143143
Stores a fixed `x` and divides by it, then squares the result.
144144
@@ -148,8 +148,11 @@ julia> map(Divider(2), [1 2 3 4 10])
148148
1×5 Matrix{Float64}:
149149
0.25 1.0 2.25 4.0 25.0
150150
```
151+
152+
Unlike our other functors, this one subtypes `Function` so can be used to test things
153+
with that restrictiion
151154
"""
152-
struct Divider{T<:Real}
155+
struct Divider{T<:Real} <: Function
153156
x::T
154157
end
155158
(d::Divider)(y::Real) = (y / d.x)^2

0 commit comments

Comments
 (0)