|
15 | 15 | test_rrule(mean, abs, [-4.0, 2.0, 2.0]) |
16 | 16 | test_rrule(mean, log, rand(3, 4) .+ 1) |
17 | 17 | test_rrule(mean, cbrt, randn(5)) |
18 | | - multiplier = Multiplier(2.0) |
19 | | - test_rrule(mean, x->multiplier(x), [2.0, 4.0, 8.0]) # defined in test_helpers.jl |
20 | | - divider = Divider(1 + rand()) |
21 | | - test_rrule(mean, x->divider(x), randn(5)) |
| 18 | + |
| 19 | + test_rrule(mean, Divider(1 + rand()), randn(5)) # defined in test_helpers.jl |
22 | 20 |
|
23 | 21 | test_rrule(mean, sum, [[2.0, 4.0], [4.0,1.9]]; check_inferred=false) |
24 | 22 |
|
25 | 23 | test_rrule(mean, log, rand(ComplexF64, 5)) |
26 | 24 | test_rrule(mean, sqrt, rand(ComplexF64, 5)) |
27 | | - test_rrule(mean, abs, rand(ComplexF64, 3, 4)) |
| 25 | + test_rrule(mean, abs, rand(ComplexF164, 3, 4)) |
28 | 26 |
|
29 | 27 | test_rrule(mean, abs, [-2.0 4.0; 5.0 1.9]; fkwargs=(;dims=1)) |
30 | 28 | test_rrule(mean, abs, [-2.0 4.0; 5.0 1.9]; fkwargs=(;dims=2)) |
31 | 29 | test_rrule(mean, sqrt, rand(ComplexF64, 3, 4); fkwargs=(;dims=(1,))) |
32 | 30 | end |
33 | 31 |
|
34 | 32 | @testset "Regression Test against StatsBase-like Weighted Mean" begin |
35 | | - struct DummyWeights <: AbstractVector{Float64} # DummyType that looks like StatsBase's Weights types |
| 33 | + @eval struct DummyWeights <: AbstractVector{Float64} # DummyType that looks like StatsBase's Weights types |
36 | 34 | end |
37 | 35 | # This should hit the fallback "nothing" rule indicating no rule was defined |
38 | 36 | @test nothing == rrule(ChainRulesTestUtils.TestConfig(), mean, [1.0, 2.0], DummyWeights()) |
|
0 commit comments