From 48dd8741b6517af7aa4be50440d1dd0ebe522d1c Mon Sep 17 00:00:00 2001 From: Andreas Noack Date: Thu, 28 Oct 2021 21:00:36 +0200 Subject: [PATCH] =?UTF-8?q?Make=20=5F=E2=82=82F=E2=82=81=20inferred=20for?= =?UTF-8?q?=20Float32=20inputs=20by=20promoting=20the=20arguments.=20(#43)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Make _₂F₁ inferred for Float32 inputs by promoting the arguments. * Bump the version number * Update test/runtests.jl Co-authored-by: David Widmann Co-authored-by: David Widmann --- Project.toml | 2 +- src/gauss.jl | 6 ++-- src/specialfunctions.jl | 76 ++++++++++++++++++++++++++++++++++------- test/runtests.jl | 3 ++ 4 files changed, 70 insertions(+), 17 deletions(-) diff --git a/Project.toml b/Project.toml index 9ab8a84..8faf7bc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "HypergeometricFunctions" uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" -version = "0.3.5" +version = "0.3.6" [deps] DualNumbers = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" diff --git a/src/gauss.jl b/src/gauss.jl index 007d644..3ce3c96 100644 --- a/src/gauss.jl +++ b/src/gauss.jl @@ -14,7 +14,7 @@ function _₂F₁(a, b, c, z) if isequal(a+b, 0) # 31. 15.4.11 & 15.4.12 return cosnasinsqrt(2b, z) elseif isequal(a+b, 1) # 32. 15.4.13 & 15.4.14 - return cosnasinsqrt(1-2b, z)*exp(-0.5log1p(-z)) + return cosnasinsqrt(1-2b, z)*exp(-log1p(-z)/2) elseif isequal(b-a, 0.5) # 15.4.7 & 15.4.8 return expnlog1pcoshatanhsqrt(-2a, z) end @@ -22,13 +22,13 @@ function _₂F₁(a, b, c, z) if abeqcd(a, b, 0.5) # 13. 15.4.4 & 15.4.5 return sqrtasinsqrt(z) elseif abeqcd(a, b, 1) # 14. - return sqrtasinsqrt(z)*exp(-0.5log1p(-z)) + return sqrtasinsqrt(z)*exp(-log1p(-z)/2) elseif abeqcd(a, b, 0.5, 1) # 15. 15.4.2 & 15.4.3 return sqrtatanhsqrt(z) elseif isequal(a+b, 1) # 29. 15.4.15 & 15.4.16 return sinnasinsqrt(1-2b, z) elseif isequal(a+b, 2) # 30. - return sinnasinsqrt(2-2b, z)*exp(-0.5log1p(-z)) + return sinnasinsqrt(2-2b, z)*exp(-log1p(-z)/2) elseif isequal(b-a, 0.5) # 4. 15.4.9 & 15.4.10 return expnlog1psinhatanhsqrt(1-2a, z) end diff --git a/src/specialfunctions.jl b/src/specialfunctions.jl index 8e2b24b..393819f 100644 --- a/src/specialfunctions.jl +++ b/src/specialfunctions.jl @@ -220,8 +220,8 @@ function G(z::Union{Float64, ComplexF64, Dual128, DualComplex256}, ϵ::Union{Flo end end -G(z::Number, ϵ::Number) = ϵ == 0 ? digamma(z)/unsafe_gamma(z) : (inv(unsafe_gamma(z))-inv(unsafe_gamma(z+ϵ)))/ϵ - +G(z::T, ϵ::T) where {T<:Number} = ϵ == 0 ? digamma(z)/unsafe_gamma(z) : (inv(unsafe_gamma(z))-inv(unsafe_gamma(z+ϵ)))/ϵ +G(z::Number, ϵ::Number) = G(promote(z, ϵ)...) """ Compute the function ((z+ϵ)ₘ-(z)ₘ)/ϵ @@ -271,9 +271,35 @@ G(z::AbstractVector{BigFloat}, ϵ::BigFloat) = BigFloat[G(zi, ϵ) for zi in z] # Transformation formula w = 1-z -reconeα₀(a, b, c, m::Int, ϵ) = ϵ == 0 ? (-1)^m*gamma(m)*gamma(c)/(gamma(a+m)*gamma(b+m)) : gamma(c)/(ϵ*gamma(1-m-ϵ)*gamma(a+m+ϵ)*gamma(b+m+ϵ)) -reconeβ₀(a, b, c, w, m::Int, ϵ) = abs(ϵ) > 0.1 ? ( pochhammer(float(a), m)*pochhammer(b, m)/(gamma(1-ϵ)*gamma(a+m+ϵ)*gamma(b+m+ϵ)*gamma(m+1)) - w^ϵ/(gamma(a)*gamma(b)*gamma(m+1+ϵ)) )*gamma(c)*w^m/ϵ : ( (G(1.0, -ϵ)/gamma(m+1)+G(m+1.0, ϵ))/(gamma(a+m+ϵ)*gamma(b+m+ϵ)) - (G(float(a)+m, ϵ)/gamma(b+m+ϵ)+G(float(b)+m, ϵ)/gamma(a+m))/gamma(m+1+ϵ) - E(log(w), ϵ)/(gamma(a+m)*gamma(b+m)*gamma(m+1+ϵ)) )*gamma(c)*pochhammer(float(a), m)*pochhammer(b, m)*w^m -reconeγ₀(a, b, c, w, m::Int, ϵ) = gamma(c)*pochhammer(float(a), m)*pochhammer(b, m)*w^m/(gamma(a+m+ϵ)*gamma(b+m+ϵ)*gamma(m+1)*gamma(1-ϵ)) +function reconeα₀(a, b, c, m::Int, ϵ) + _a, _b, _c, _ϵ = promote(a, b, c, ϵ) + return _reconeα₀(_a, _b, _c, m, _ϵ) +end +function _reconeα₀(a::T, b::T, c::T, m::Int, ϵ::T) where {T} + if ϵ == 0 + return (-1)^m*gamma(real(T)(m))*gamma(c)/(gamma(a+m)*gamma(b+m)) + else + return gamma(c)/(ϵ*gamma(1-m-ϵ)*gamma(a+m+ϵ)*gamma(b+m+ϵ)) + end +end +function reconeβ₀(a, b, c, w, m::Int, ϵ) + _a, _b, _c, _, _ϵ = promote(a, b, c, real(w), ϵ) + _w, _ = promote(w, zero(_a)) + return _reconeβ₀(_a, _b, _c, _w, m, _ϵ) +end +function _reconeβ₀(a::T, b::T, c::T, w::Number, m::Int, ϵ::T) where {T} + if abs(ϵ) > 0.1 + return ( pochhammer(a, m)*pochhammer(b, m)/(gamma(1-ϵ)*gamma(a+m+ϵ)*gamma(b+m+ϵ)*gamma(real(T)(m)+1)) - w^ϵ/(gamma(a)*gamma(b)*gamma(m+1+ϵ)) )*gamma(c)*w^m/ϵ + else + return ( (G(1, -ϵ)/gamma(real(T)(m)+1)+G(m+1, ϵ))/(gamma(a+m+ϵ)*gamma(b+m+ϵ)) - (G(a+m, ϵ)/gamma(b+m+ϵ)+G(float(b)+m, ϵ)/gamma(a+m))/gamma(m+1+ϵ) - E(log(w), ϵ)/(gamma(a+m)*gamma(b+m)*gamma(m+1+ϵ)) )*gamma(c)*pochhammer(a, m)*pochhammer(b, m)*w^m + end +end +function reconeγ₀(a, b, c, w, m::Int, ϵ) + _a, _b, _c, _, _ϵ = promote(a, b, c, real(w), ϵ) + _w, _ = promote(w, zero(_a)) + return _reconeγ₀(_a, _b, _c, _w, m, _ϵ) +end +_reconeγ₀(a::T, b::T, c::T, w::Number, m::Int, ϵ::T) where {T} = gamma(c)*pochhammer(a, m)*pochhammer(b, m)*w^m/(gamma(a+m+ϵ)*gamma(b+m+ϵ)*gamma(real(T)(m)+1)*gamma(1-ϵ)) function Aone(a, b, c, w, m::Int, ϵ) αₙ = reconeα₀(a, b, c, m, ϵ)*one(w) @@ -306,14 +332,38 @@ end # Transformation formula w = 1/z -recInfα₀(a, b, c, m::Int, ϵ) = ϵ == 0 ? (-1)^m*gamma(m)*gamma(c)/(gamma(a+m)*gamma(c-a)) : gamma(c)/(ϵ*gamma(1-m-ϵ)*gamma(a+m+ϵ)*gamma(c-a)) -recInfβ₀(a, b, c, w, m::Int, ϵ) = abs(ϵ) > 0.1 ? -( pochhammer(float(a), m)*pochhammer(float(1-c+a), m)/(gamma(1-ϵ)*gamma(a+m+ϵ)*gamma(c-a)*gamma(m+1)) - - (-w)^ϵ*pochhammer(float(1-c+a)+ϵ, m)/(gamma(a)*gamma(c-a-ϵ)*gamma(m+1+ϵ)) )*gamma(c)*w^m/ϵ : -( (pochhammer(float(1-c+a)+ϵ, m)*G(1.0, -ϵ)-P(1-c+a, ϵ, m)/gamma(1-ϵ))/(gamma(c-a)*gamma(a+m+ϵ)*gamma(m+1)) + - pochhammer(float(1-c+a)+ϵ, m)*( (G(m+1.0, ϵ)/gamma(a+m+ϵ) - G(float(a)+m, ϵ)/gamma(m+1+ϵ))/gamma(c-a) - -(G(float(c-a), -ϵ) - E(-log(-w), -ϵ)/gamma(c-a-ϵ))/(gamma(m+1+ϵ)*gamma(a+m)) ) )*gamma(c)*pochhammer(float(a), m)*w^m -recInfγ₀(a, b, c, w, m::Int, ϵ) = gamma(c)*pochhammer(float(a), m)*pochhammer(float(1-c+a), m)*w^m/(gamma(a+m+ϵ)*gamma(c-a)*gamma(m+1)*gamma(1-ϵ)) +function recInfα₀(a, b, c, m::Int, ϵ) + _a, _b, _c, _ϵ = promote(a, b, c, ϵ) + return _recInfα₀(_a, _b, _c, m, _ϵ) +end +function _recInfα₀(a::T, b::T, c::T, m::Int, ϵ::T) where {T} + if ϵ == 0 + return (-1)^m*gamma(real(T)(m))*gamma(c)/(gamma(a+m)*gamma(c-a)) + else + return gamma(c)/(ϵ*gamma(1-m-ϵ)*gamma(a+m+ϵ)*gamma(c-a)) + end +end +function recInfβ₀(a, b, c, w, m::Int, ϵ) + _a, _b, _c, _, _ϵ = promote(a, b, c, real(w), ϵ) + _w, _ = promote(w, zero(_a)) + return _recInfβ₀(_a, _b, _c, _w, m, _ϵ) +end +function _recInfβ₀(a::T, b::T, c::T, w::Number, m::Int, ϵ::T) where {T} + if abs(ϵ) > 0.1 + return ( pochhammer(a, m)*pochhammer(1-c+a, m)/(gamma(1-ϵ)*gamma(a+m+ϵ)*gamma(c-a)*gamma(real(T)(m)+1)) - + (-w)^ϵ*pochhammer(1-c+a+ϵ, m)/(gamma(a)*gamma(c-a-ϵ)*gamma(m+1+ϵ)) )*gamma(c)*w^m/ϵ + else + return ( (pochhammer(1-c+a+ϵ, m)*G(1, -ϵ)-P(1-c+a, ϵ, m)/gamma(1-ϵ))/(gamma(c-a)*gamma(a+m+ϵ)*gamma(real(T)(m)+1)) + + pochhammer(1-c+a+ϵ, m)*( (G(m+1, ϵ)/gamma(a+m+ϵ) - G(a+m, ϵ)/gamma(m+1+ϵ))/gamma(c-a) - + (G(c-a, -ϵ) - E(-log(-w), -ϵ)/gamma(c-a-ϵ))/(gamma(m+1+ϵ)*gamma(a+m)) ) )*gamma(c)*pochhammer(a, m)*w^m + end +end +function recInfγ₀(a, b, c, w, m::Int, ϵ) + _a, _b, _c, _, _ϵ = promote(a, b, c, real(w), ϵ) + _w, _ = promote(w, zero(_a)) + return _recInfγ₀(_a, _b, _c, _w, m, _ϵ) +end +_recInfγ₀(a::T, b::T, c::T, w::Number, m::Int, ϵ::T) where {T} = gamma(c)*pochhammer(a, m)*pochhammer(1-c+a, m)*w^m/(gamma(a+m+ϵ)*gamma(c-a)*gamma(real(T)(m)+1)*gamma(1-ϵ)) function AInf(a, b, c, w, m::Int, ϵ) αₙ = recInfα₀(a, b, c, m, ϵ)*one(w) diff --git a/test/runtests.jl b/test/runtests.jl index 6523a18..c209d35 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -33,6 +33,9 @@ const NumberType = Float64 j += 1 end end + @testset "Test that _₂F₁ is inferred for Float32 arguments" begin + @test @inferred(_₂F₁(0.3f0, 0.7f0, 1.3f0, 0.1f0)) ≈ Float32(_₂F₁(0.3, 0.7, 1.3, 0.1)) + end end