Skip to content

Commit

Permalink
Merge pull request #21589 from JuliaLang/yyc/cis
Browse files Browse the repository at this point in the history
Use `sincos` from libm in `cis`
  • Loading branch information
yuyichao authored May 25, 2017
2 parents 888adbc + 4ed22b1 commit bac32d3
Show file tree
Hide file tree
Showing 9 changed files with 105 additions and 17 deletions.
29 changes: 19 additions & 10 deletions base/complex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,10 @@ sqrt(z::Complex) = sqrt(float(z))
# end

# compute exp(im*theta)
cis(theta::Real) = Complex(cos(theta),sin(theta))
function cis(theta::Real)
s, c = sincos(theta)
Complex(c, s)
end

"""
cis(z)
Expand All @@ -433,7 +436,8 @@ Return ``\\exp(iz)``.
"""
function cis(z::Complex)
v = exp(-imag(z))
Complex(v*cos(real(z)), v*sin(real(z)))
s, c = sincos(real(z))
Complex(v * c, v * s)
end

"""
Expand Down Expand Up @@ -510,7 +514,8 @@ function exp(z::Complex)
if iszero(zi)
Complex(er, zi)
else
Complex(er*cos(zi), er*sin(zi))
s, c = sincos(zi)
Complex(er * c, er * s)
end
end
end
Expand Down Expand Up @@ -538,7 +543,8 @@ function expm1(z::Complex{T}) where T<:Real
wr = erm1 - 2 * er * (sin(convert(Tf, 0.5) * zi))^2
return Complex(wr, er * sin(zi))
else
return Complex(er * cos(zi), er * sin(zi))
s, c = sincos(zi)
return Complex(er * c, er * s)
end
end
end
Expand Down Expand Up @@ -600,13 +606,15 @@ end
function exp2(z::Complex{T}) where T
er = exp2(real(z))
theta = imag(z) * log(convert(T, 2))
Complex(er*cos(theta), er*sin(theta))
s, c = sincos(theta)
Complex(er * c, er * s)
end

function exp10(z::Complex{T}) where T
er = exp10(real(z))
theta = imag(z) * log(convert(T, 10))
Complex(er*cos(theta), er*sin(theta))
s, c = sincos(theta)
Complex(er * c, er * s)
end

function ^(z::T, p::T) where T<:Complex
Expand All @@ -628,8 +636,7 @@ function ^(z::T, p::T) where T<:Complex
rp = rp*exp(-pim*theta)
ntheta = ntheta + pim*log(r)
end
cosntheta = cos(ntheta)
sinntheta = sin(ntheta)
sinntheta, cosntheta = sincos(ntheta)
re, im = rp*cosntheta, rp*sinntheta
if isinf(rp)
if isnan(re)
Expand Down Expand Up @@ -689,7 +696,8 @@ function sin(z::Complex{T}) where T
Complex(F(NaN), F(NaN))
end
else
Complex(sin(zr)*cosh(zi), cos(zr)*sinh(zi))
s, c = sincos(zr)
Complex(s * cosh(zi), c * sinh(zi))
end
end

Expand All @@ -708,7 +716,8 @@ function cos(z::Complex{T}) where T
Complex(F(NaN), F(NaN))
end
else
Complex(cos(zr)*cosh(zi), -sin(zr)*sinh(zi))
s, c = sincos(zr)
Complex(c * cosh(zi), -s * sinh(zi))
end
end

Expand Down
1 change: 1 addition & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,7 @@ export
significand,
sin,
sinc,
sincos,
sind,
sinh,
sinpi,
Expand Down
45 changes: 44 additions & 1 deletion base/fastmath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ const fast_op =
:min => :min_fast,
:minmax => :minmax_fast,
:sin => :sin_fast,
:sincos => :sincos_fast,
:sinh => :sinh_fast,
:sqrt => :sqrt_fast,
:tan => :tan_fast,
Expand Down Expand Up @@ -273,6 +274,45 @@ atan2_fast(x::Float64, y::Float64) =

# explicit implementations

# FIXME: Change to `ccall((:sincos, libm))` when `Ref` calling convention can be
# stack allocated.
@inline function sincos_fast(v::Float64)
return Base.llvmcall("""
%f = bitcast i8 *%1 to void (double, double *, double *)*
%ps = alloca double
%pc = alloca double
call void %f(double %0, double *%ps, double *%pc)
%s = load double, double* %ps
%c = load double, double* %pc
%res0 = insertvalue [2 x double] undef, double %s, 0
%res = insertvalue [2 x double] %res0, double %c, 1
ret [2 x double] %res
""", Tuple{Float64,Float64}, Tuple{Float64,Ptr{Void}}, v, cglobal((:sincos, libm)))
end

@inline function sincos_fast(v::Float32)
return Base.llvmcall("""
%f = bitcast i8 *%1 to void (float, float *, float *)*
%ps = alloca float
%pc = alloca float
call void %f(float %0, float *%ps, float *%pc)
%s = load float, float* %ps
%c = load float, float* %pc
%res0 = insertvalue [2 x float] undef, float %s, 0
%res = insertvalue [2 x float] %res0, float %c, 1
ret [2 x float] %res
""", Tuple{Float32,Float32}, Tuple{Float32,Ptr{Void}}, v, cglobal((:sincosf, libm)))
end

@inline function sincos_fast(v::Float16)
s, c = sincos_fast(Float32(v))
return Float16(s), Float16(c)
end

sincos_fast(v::AbstractFloat) = (sin_fast(v), cos_fast(v))
sincos_fast(v::Real) = sincos_fast(float(v)::AbstractFloat)
sincos_fast(v) = (sin_fast(v), cos_fast(v))

@fastmath begin
exp10_fast(x::T) where {T<:FloatTypes} = exp2(log2(T(10))*x)
exp10_fast(x::Integer) = exp10(float(x))
Expand All @@ -287,7 +327,10 @@ atan2_fast(x::Float64, y::Float64) =

# complex numbers

cis_fast(x::T) where {T<:FloatTypes} = Complex{T}(cos(x), sin(x))
function cis_fast(x::T) where {T<:FloatTypes}
s, c = sincos_fast(x)
Complex{T}(c, s)
end

# See <http://en.cppreference.com/w/cpp/numeric/complex>
pow_fast(x::T, y::T) where {T<:ComplexTypes} = exp(y*log(x))
Expand Down
15 changes: 14 additions & 1 deletion base/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

module Math

export sin, cos, tan, sinh, cosh, tanh, asin, acos, atan,
export sin, cos, sincos, tan, sinh, cosh, tanh, asin, acos, atan,
asinh, acosh, atanh, sec, csc, cot, asec, acsc, acot,
sech, csch, coth, asech, acsch, acoth,
sinpi, cospi, sinc, cosc,
Expand Down Expand Up @@ -419,6 +419,19 @@ for f in (:sin, :cos, :tan, :asin, :acos, :acosh, :atanh, :log, :log2, :log10,
end
end

"""
sincos(x)
Compute sine and cosine of `x`, where `x` is in radians.
"""
@inline function sincos(x)
res = Base.FastMath.sincos_fast(x)
if (isnan(res[1]) | isnan(res[2])) & !isnan(x)
throw(DomainError())
end
return res
end

sqrt(x::Float64) = sqrt_llvm(x)
sqrt(x::Float32) = sqrt_llvm(x)

Expand Down
13 changes: 12 additions & 1 deletion base/mpfr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import
nextfloat, prevfloat, promote_rule, rem, rem2pi, round, show, float,
sum, sqrt, string, print, trunc, precision, exp10, expm1,
gamma, lgamma, log1p,
eps, signbit, sin, cos, tan, sec, csc, cot, acos, asin, atan,
eps, signbit, sin, cos, sincos, tan, sec, csc, cot, acos, asin, atan,
cosh, sinh, tanh, sech, csch, coth, acosh, asinh, atanh, atan2,
cbrt, typemax, typemin, unsafe_trunc, realmin, realmax, rounding,
setrounding, maxintfloat, widen, significand, frexp, tryparse, iszero, big
Expand All @@ -24,6 +24,8 @@ import Base.GMP: ClongMax, CulongMax, CdoubleMax, Limb

import Base.Math.lgamma_r

import Base.FastMath.sincos_fast

function __init__()
try
# set exponent to full range by default
Expand Down Expand Up @@ -515,6 +517,15 @@ for f in (:exp, :exp2, :exp10, :expm1, :cosh, :sinh, :tanh, :sech, :csch, :coth,
end
end

function sincos_fast(v::BigFloat)
s = BigFloat()
c = BigFloat()
ccall((:mpfr_sin_cos, :libmpfr), Int32, (Ptr{BigFloat}, Ptr{BigFloat}, Ptr{BigFloat}, Int32),
&s, &c, &v, ROUNDING_MODE[])
return (s, c)
end
sincos(v::BigFloat) = sincos_fast(v)

# return log(2)
function big_ln2()
c = BigFloat()
Expand Down
8 changes: 4 additions & 4 deletions base/sysimg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,10 @@ importall .Order
include("sort.jl")
importall .Sort

# Fast math
include("fastmath.jl")
importall .FastMath

function deepcopy_internal end

# BigInts and BigFloats
Expand Down Expand Up @@ -362,10 +366,6 @@ importall .DFT
include("dsp.jl")
importall .DSP

# Fast math
include("fastmath.jl")
importall .FastMath

# libgit2 support
include("libgit2/libgit2.jl")

Expand Down
1 change: 1 addition & 0 deletions doc/src/stdlib/math.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ Base.:(!)
Base.isapprox
Base.sin
Base.cos
Base.sincos
Base.tan
Base.Math.sind
Base.Math.cosd
Expand Down
1 change: 1 addition & 0 deletions test/fastmath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
@test macroexpand(:(@fastmath min(1))) == :(Base.FastMath.min_fast(1))
@test macroexpand(:(@fastmath min)) == :(Base.FastMath.min_fast)
@test macroexpand(:(@fastmath x.min)) == :(x.min)
@test macroexpand(:(@fastmath sincos(x))) == :(Base.FastMath.sincos_fast(x))

# basic arithmetic

Expand Down
9 changes: 9 additions & 0 deletions test/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -627,3 +627,12 @@ end
@testset "promote Float16 irrational #15359" begin
@test typeof(Float16(.5) * pi) == Float16
end

@testset "sincos" begin
@test sincos(1.0) === (sin(1.0), cos(1.0))
@test sincos(1f0) === (sin(1f0), cos(1f0))
@test sincos(Float16(1)) === (sin(Float16(1)), cos(Float16(1)))
@test sincos(1) === (sin(1), cos(1))
@test sincos(big(1)) == (sin(big(1)), cos(big(1)))
@test sincos(big(1.0)) == (sin(big(1.0)), cos(big(1.0)))
end

0 comments on commit bac32d3

Please sign in to comment.