Skip to content

Commit

Permalink
Implement erfinv for Float32 and Float64
Browse files Browse the repository at this point in the history
The implementations are based on the Julia package SpecialFunctions.jl.
  • Loading branch information
ararslan committed Dec 31, 2023
1 parent 5b2df9b commit 1564acf
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 0 deletions.
4 changes: 4 additions & 0 deletions lib/complex.dx
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ def complex_erf(x:Complex) -> Complex =
def complex_erfc(x:Complex) -> Complex =
todo

def complex_erfinv(x:Complex) -> Complex =
todo

def complex_log1p(x:Complex) -> Complex =
case x.re == 0.0 of
True -> x
Expand Down Expand Up @@ -130,3 +133,4 @@ instance Floating(Complex)
def lgamma(x) = complex_lgamma(x)
def erf(x) = complex_erf(x)
def erfc(x) = complex_erfc(x)
def erfinv(x) = complex_erfinv(x)
125 changes: 125 additions & 0 deletions lib/prelude.dx
Original file line number Diff line number Diff line change
Expand Up @@ -1050,6 +1050,7 @@ interface Floating(a:Type)
lgamma : (a) -> a
erf : (a) -> a
erfc : (a) -> a
erfinv : (a) -> a

def lbeta(x:a, y:a) -> a given (a|Sub|Floating) = lgamma x + lgamma y - lgamma (x + y)

Expand All @@ -1066,6 +1067,127 @@ def float64_cosh(x:Float64) -> Float64 = %fdiv(%fadd(%exp(x), %exp(%fsub(f_to_f6
def float64_tanh(x:Float64) -> Float64 = %fdiv(%fsub(%exp(x), %exp(%fsub(f_to_f64 0.0, x)))
,%fadd(%exp(x), %exp(%fsub(f_to_f64 0.0, x))))

# Polynomial evaluation by Horner's method
def unsafe_horner(x:a, ys:n=>a) -> a given (a|Add|Mul, n|Ix) =
n' = unsafe_i_to_n(n_to_i(size n) - 1)
yield_state ys[unsafe_from_ordinal n'] \ref. rof i:(Fin n').
ref := ys[unsafe_from_ordinal (ordinal i)] + x * get ref

# `erfinv` implementations for `Float32` and `Float64` are based on those in Julia in
# https://github.com/JuliaMath/SpecialFunctions.jl, which uses the following reference:
# Blair, J. M., Edwards, C. A., & Johnson, J. H. (1976). Rational Chebyshev approximations
# for the inverse of the error function. In Mathematics of Computation (Vol. 30, Issue 136,
# pp. 827–830). American Mathematical Society (AMS).
# https://doi.org/10.1090/s0025-5718-1976-0421040-7
def float32_erfinv(x:Float32) -> Float32 =
a = select(x > 0.0, x, -x)
if a >= 1.0
then
inf = f_to_f32(1.0 / 0.0)
if x == 1.0
then inf
else
if x == -1.0
then -inf
else f_to_f32(0.0 / 0.0) # TODO: this should probably error but `error` is not defined yet
else
if a <= 0.75 # Blair table 10
then
t = x * x - 0.5625
p1 = unsafe_horner t [-0.130959967422e+2, 0.26785225760e+2, -0.9289057365e+1]
p2 = unsafe_horner t [-0.120749426297e+2, 0.30960614529e+2, -0.17149977991e+2, 0.1e+1]
f_to_f32(x * (p1 / p2))
else
if a <= 0.9375 # Blair table 29
then
t = x * x - 0.87890625
p1 = unsafe_horner t [-0.12402565221, 0.10688059574e+1, -0.19594556078e+1, 0.4230581357]
p2 = unsafe_horner t [-0.8827697997e-1, 0.8900743359, -0.21757031196e+1, 0.1e+1]
f_to_f32(x * (p1 / p2))
else # Blair table 50
t = 1.0 / %sqrt(-%log1p(-a))
p1 = unsafe_horner t [-0.8827697997e-1, 0.8900743359, -0.21757031196e+1, 0.1e+1]
p2 = unsafe_horner t [0.155024849822, 0.1385228141995e+1, 0.1e+1]
s = select(x > 0.0, t, select(x < 0.0, (-t), 0.0))
f_to_f32(p1 / (s * p2))

def float64_erfinv(x:Float64) -> Float64 =
zero64 = (zero::Float64)
one64 = (one::Float64)
a = select(x > zero64, x, %fsub(zero64, x))
if a >= one64
then
inf = %fdiv(one64, zero64)
if x == one64
then inf
else
if x == f_to_f64(-1.0)
then %fsub(zero64, inf)
else %fdiv(zero64, zero64)
else
if a <= f_to_f64(0.75) # Blair table 17
then
t = %fsub(%fmul(x, x), f_to_f64(0.5625))
p1 = unsafe_horner t [f_to_f64( 0.160304955844066229311e2),
f_to_f64(-0.90784959262960326650e2),
f_to_f64( 0.18644914861620987391e3),
f_to_f64(-0.16900142734642382420e3),
f_to_f64( 0.6545466284794487048e2),
f_to_f64(-0.864213011587247794e1),
f_to_f64( 0.1760587821390590)]
p2 = unsafe_horner t [f_to_f64( 0.147806470715138316110e2),
f_to_f64(-0.91374167024260313936e2),
f_to_f64( 0.21015790486205317714e3),
f_to_f64(-0.22210254121855132366e3),
f_to_f64( 0.10760453916055123830e3),
f_to_f64(-0.206010730328265443e2),
f_to_f64( 0.1e1)]
%fmul(x, %fdiv(p1, p2))
else
if a <= f_to_f64(0.9375) # Blair table 37
then
t = %fsub(%fmul(x, x), f_to_f64(0.87890625))
p1 = unsafe_horner t [f_to_f64(-0.152389263440726128e-1),
f_to_f64( 0.3444556924136125216),
f_to_f64(-0.29344398672542478687e1),
f_to_f64( 0.11763505705217827302e2),
f_to_f64(-0.22655292823101104193e2),
f_to_f64( 0.19121334396580330163e2),
f_to_f64(-0.5478927619598318769e1),
f_to_f64( 0.237516689024448)]
p2 = unsafe_horner t [f_to_f64(-0.108465169602059954e-1),
f_to_f64( 0.2610628885843078511),
f_to_f64(-0.24068318104393757995e1),
f_to_f64( 0.10695129973387014469e2),
f_to_f64(-0.23716715521596581025e2),
f_to_f64( 0.24640158943917284883e2),
f_to_f64(-0.10014376349783070835e2),
f_to_f64( 0.1e1)]
%fmul(x, %fdiv(p1, p2))
else # Blair table 57
t = %fdiv(one64, %sqrt(%fsub(zero64, %log1p(%fsub(zero64, a)))))
p1 = unsafe_horner t [f_to_f64(0.10501311523733438116e-3),
f_to_f64(0.1053261131423333816425e-1),
f_to_f64(0.26987802736243283544516),
f_to_f64(0.23268695788919690806414e1),
f_to_f64(0.71678547949107996810001e1),
f_to_f64(0.85475611822167827825185e1),
f_to_f64(0.68738088073543839802913e1),
f_to_f64(0.3627002483095870893002e1),
f_to_f64(0.886062739296515468149)]
p2 = unsafe_horner t [f_to_f64(0.10501266687030337690e-3),
f_to_f64(0.1053286230093332753111e-1),
f_to_f64(0.27019862373751554845553),
f_to_f64(0.23501436397970253259123e1),
f_to_f64(0.76078028785801277064351e1),
f_to_f64(0.111815861040569078273451e2),
f_to_f64(0.119487879184353966678438e2),
f_to_f64(0.81922409747269907893913e1),
f_to_f64(0.4099387907636801536145e1),
f_to_f64(0.1e1)]
s = select(x > zero64, t, select(x < zero64, %fsub(zero64, t), zero64))
%fdiv(p1, %fmul(s, p2))

instance Floating(Float64)
def exp(x) = %exp(x)
def exp2(x) = %exp2(x)
Expand All @@ -1087,6 +1209,7 @@ instance Floating(Float64)
def lgamma(x)= %lgamma(x)
def erf(x) = %erf(x)
def erfc(x) = %erfc(x)
def erfinv(x)= float64_erfinv(x)

instance Floating(Float32)
def exp(x) = %exp(x)
Expand All @@ -1109,6 +1232,7 @@ instance Floating(Float32)
def lgamma(x)= %lgamma(x)
def erf(x) = %erf(x)
def erfc(x) = %erfc(x)
def erfinv(x)= float32_erfinv(x)

'## Raw pointer operations

Expand Down Expand Up @@ -1249,6 +1373,7 @@ instance Floating(n=>a) given (a|Floating, n|Ix)
def lgamma(x) = each x lgamma
def erf(x) = each x erf
def erfc(x) = each x erfc
def erfinv(x) = each x erfinv

'### Reductions

Expand Down
15 changes: 15 additions & 0 deletions tests/eval-tests.dx
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,21 @@ fun = \y. sum (map n_to_f arr) + y
:p f_to_i $ round 3.6
> 4

:p erfinv(f_to_f64 0.84270079294971486934)
> 1.

:p erfinv 1.0
> inf

-- TODO: This should actually be an error since it's outside of the domain of the function
:p erfinv 2.0
> nan

:p
xs = each [-0.99, -0.75, -0.5, -0.25, 0.0, 0.25, 0.5, 0.75, 0.99] f_to_f64
erf(erfinv xs) ~~ erfinv(erf xs) && erfinv(xs) ~~ each xs \x. zero - erfinv(zero - x)
> True

s = 1.0

:p s
Expand Down

0 comments on commit 1564acf

Please sign in to comment.