Skip to content

Commit

Permalink
Merge pull request #308 from devmotion/dw/notimplemented
Browse files Browse the repository at this point in the history
Use `ChainRulesCore.@not_implemented` and extend tests
  • Loading branch information
andreasnoack authored May 17, 2021
2 parents feccbf4 + c1f7012 commit a48ba24
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 32 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
OpenSpecFun_jll = "efe28fd5-8261-553b-a9e1-b2916fc3738e"

[compat]
ChainRulesCore = "0.9"
ChainRulesTestUtils = "0.6.3"
ChainRulesCore = "0.9.40"
ChainRulesTestUtils = "0.6.8"
LogExpFunctions = "0.2"
OpenSpecFun_jll = "0.5"
julia = "1.3"
Expand Down
19 changes: 12 additions & 7 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
const BESSEL_ORDER_INFO = """
derivatives of Bessel functions with respect to the order are not implemented currently:
https://github.com/JuliaMath/SpecialFunctions.jl/issues/160
"""

ChainRulesCore.@scalar_rule(airyai(x), airyaiprime(x))
ChainRulesCore.@scalar_rule(airyaiprime(x), x * airyai(x))
ChainRulesCore.@scalar_rule(airybi(x), airybiprime(x))
Expand Down Expand Up @@ -31,49 +36,49 @@ ChainRulesCore.@scalar_rule(trigamma(x), polygamma(2, x))
ChainRulesCore.@scalar_rule(
besselj(ν, x),
(
ChainRulesCore.@thunk(error("not implemented")),
ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO),
(besselj- 1, x) - besselj+ 1, x)) / 2
),
)
ChainRulesCore.@scalar_rule(
besseli(ν, x),
(
ChainRulesCore.@thunk(error("not implemented")),
ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO),
(besseli- 1, x) + besseli+ 1, x)) / 2,
),
)
ChainRulesCore.@scalar_rule(
bessely(ν, x),
(
ChainRulesCore.@thunk(error("not implemented")),
ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO),
(bessely- 1, x) - bessely+ 1, x)) / 2,
),
)
ChainRulesCore.@scalar_rule(
besselk(ν, x),
(
ChainRulesCore.@thunk(error("not implemented")),
ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO),
-(besselk- 1, x) + besselk+ 1, x)) / 2,
),
)
ChainRulesCore.@scalar_rule(
hankelh1(ν, x),
(
ChainRulesCore.@thunk(error("not implemented")),
ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO),
(hankelh1- 1, x) - hankelh1+ 1, x)) / 2,
),
)
ChainRulesCore.@scalar_rule(
hankelh2(ν, x),
(
ChainRulesCore.@thunk(error("not implemented")),
ChainRulesCore.@not_implemented(BESSEL_ORDER_INFO),
(hankelh2- 1, x) - hankelh2+ 1, x)) / 2,
),
)
ChainRulesCore.@scalar_rule(
polygamma(m, x),
(
ChainRulesCore.@thunk(error("not implemented")),
ChainRulesCore.DoesNotExist(),
polygamma(m + 1, x),
),
)
Expand Down
80 changes: 57 additions & 23 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
@testset "chainrules" begin
Random.seed!(1)

@testset "general" begin
@testset "general: single input" begin
for x in (1.0, -1.0, 0.0, 0.5, 10.0, -17.1, 1.5 + 0.7im)
test_scalar(erf, x)
test_scalar(erfc, x)
Expand All @@ -12,9 +12,6 @@
test_scalar(airybi, x)
test_scalar(airybiprime, x)

test_scalar(besselj0, x)
test_scalar(besselj1, x)

test_scalar(erfcx, x)
test_scalar(dawson, x)

Expand All @@ -28,37 +25,74 @@
end

if x isa Real && x > 0 || x isa Complex
test_scalar(bessely0, x)
test_scalar(bessely1, x)
test_scalar(gamma, x)
test_scalar(digamma, x)
test_scalar(trigamma, x)
end
end
end

@testset "Bessel functions" begin
for x in (1.5, 2.5, 10.5, -0.6, -2.6, -3.3, 1.6 + 1.6im, 1.6 - 1.6im, -4.6 + 1.6im)
test_scalar(besselj0, x)
test_scalar(besselj1, x)

isreal(x) && x < 0 && continue

test_scalar(bessely0, x)
test_scalar(bessely1, x)

for nu in (-1.5, 2.2, 4.0)
test_frule(besseli, nu, x)
test_rrule(besseli, nu, x)

@testset "beta and logbeta" begin
test_points = (1.5, 2.5, 10.5, 1.6 + 1.6im, 1.6 - 1.6im, 4.6 + 1.6im)
for _x in test_points, _y in test_points
# ensure all complex if any complex for FiniteDifferences
x, y = promote(_x, _y)
test_frule(beta, x, y)
test_rrule(beta, x, y)
test_frule(besselj, nu, x)
test_rrule(besselj, nu, x)

test_frule(logbeta, x, y)
test_rrule(logbeta, x, y)
test_frule(besselk, nu, x)
test_rrule(besselk, nu, x)

test_frule(bessely, nu, x)
test_rrule(bessely, nu, x)

# use complex numbers in `rrule` for FiniteDifferences
test_frule(hankelh1, nu, x)
test_rrule(hankelh1, nu, complex(x))

# use complex numbers in `rrule` for FiniteDifferences
test_frule(hankelh2, nu, x)
test_rrule(hankelh2, nu, complex(x))
end
end
end

@testset "log gamma and co" begin
# It is important that we have negative numbers with both odd and even integer parts
for x in (1.5, 2.5, 10.5, -0.6, -2.6, -3.3, 1.6 + 1.6im, 1.6 - 1.6im, -4.6 + 1.6im)
isreal(x) && x < 0 && continue
test_scalar(loggamma, x)
@testset "beta and logbeta" begin
test_points = (1.5, 2.5, 10.5, 1.6 + 1.6im, 1.6 - 1.6im, 4.6 + 1.6im)
for _x in test_points, _y in test_points
# ensure all complex if any complex for FiniteDifferences
x, y = promote(_x, _y)
test_frule(beta, x, y)
test_rrule(beta, x, y)

isreal(x) || continue
test_frule(logabsgamma, x)
test_rrule(logabsgamma, x; output_tangent=(randn(), randn()))
test_frule(logbeta, x, y)
test_rrule(logbeta, x, y)
end
end

@testset "log gamma and co" begin
# It is important that we have negative numbers with both odd and even integer parts
for x in (1.5, 2.5, 10.5, -0.6, -2.6, -3.3, 1.6 + 1.6im, 1.6 - 1.6im, -4.6 + 1.6im)
for m in (0, 1, 2, 3)
test_frule(polygamma, m, x)
test_rrule(polygamma, m, x)
end

isreal(x) && x < 0 && continue
test_scalar(loggamma, x)

isreal(x) || continue
test_frule(logabsgamma, x)
test_rrule(logabsgamma, x; output_tangent=(randn(), randn()))
end
end
end

0 comments on commit a48ba24

Please sign in to comment.