Skip to content

Commit

Permalink
Support InverseFunctions where applicable (#130)
Browse files Browse the repository at this point in the history
* Support InverseFunctions where applicable

* Simplify inverse definitions

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

* Activate inverse tests

* Fix InverseFunctions.inverse tests

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
  • Loading branch information
oschulz and devmotion authored Nov 10, 2021
1 parent f13b618 commit 69e6345
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 1 deletion.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ version = "0.9.12"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand All @@ -12,6 +13,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[compat]
ChainRulesCore = "1"
InverseFunctions = "0.1"
IrrationalConstants = "0.1"
LogExpFunctions = "0.3.2"
Reexport = "1"
Expand Down
2 changes: 2 additions & 0 deletions src/StatsFuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using Base: Math.@horner
using Reexport
using SpecialFunctions
import ChainRulesCore
import InverseFunctions

# reexports
@reexport using IrrationalConstants:
Expand Down Expand Up @@ -260,5 +261,6 @@ include(joinpath("distrs", "tdist.jl"))
include(joinpath("distrs", "srdist.jl"))

include("chainrules.jl")
include("inverse.jl")

end # module
11 changes: 11 additions & 0 deletions src/inverse.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
InverseFunctions.inverse(::typeof(normcdf)) = norminvcdf
InverseFunctions.inverse(::typeof(norminvcdf)) = normcdf

InverseFunctions.inverse(::typeof(normccdf)) = norminvccdf
InverseFunctions.inverse(::typeof(norminvccdf)) = normccdf

InverseFunctions.inverse(::typeof(normlogcdf)) = norminvlogcdf
InverseFunctions.inverse(::typeof(norminvlogcdf)) = normlogcdf

InverseFunctions.inverse(::typeof(normlogccdf)) = norminvlogccdf
InverseFunctions.inverse(::typeof(norminvlogccdf)) = normlogccdf
11 changes: 11 additions & 0 deletions test/inverse.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
using StatsFuns
using Test
using InverseFunctions

@testset "inverse" begin
x = 0.7
InverseFunctions.test_inverse(norminvcdf, x)
InverseFunctions.test_inverse(norminvccdf, x)
InverseFunctions.test_inverse(norminvlogcdf, log(x))
InverseFunctions.test_inverse(norminvlogccdf, log(x))
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
tests = ["rmath", "generic", "misc", "chainrules"]
tests = ["rmath", "generic", "misc", "chainrules", "inverse"]

for t in tests
fp = "$t.jl"
Expand Down

0 comments on commit 69e6345

Please sign in to comment.