Skip to content

Commit

Permalink
Merge pull request #8112 from eschnett/fma
Browse files Browse the repository at this point in the history
Implement fma
  • Loading branch information
jakebolewski committed Jan 19, 2015
2 parents 196544c + 2353834 commit 06e2137
Show file tree
Hide file tree
Showing 9 changed files with 89 additions and 4 deletions.
1 change: 1 addition & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@ export
float32,
float64,
floor,
fma,
frexp,
gamma,
gcd,
Expand Down
3 changes: 3 additions & 0 deletions base/float.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,9 @@ widen(::Type{Float32}) = Float64
/(x::Float32, y::Float32) = box(Float32,div_float(unbox(Float32,x),unbox(Float32,y)))
/(x::Float64, y::Float64) = box(Float64,div_float(unbox(Float64,x),unbox(Float64,y)))

fma(x::Float32, y::Float32, z::Float32) = box(Float32,fma_float(unbox(Float32,x),unbox(Float32,y),unbox(Float32,z)))
fma(x::Float64, y::Float64, z::Float64) = box(Float64,fma_float(unbox(Float64,x),unbox(Float64,y),unbox(Float64,z)))

# TODO: faster floating point div?
# TODO: faster floating point fld?
# TODO: faster floating point mod?
Expand Down
3 changes: 3 additions & 0 deletions base/float16.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ abs(x::Float16) = reinterpret(Float16, reinterpret(UInt16,x) & 0x7fff)
for op in (:+,:-,:*,:/,:\,:^)
@eval ($op)(a::Float16, b::Float16) = float16(($op)(float32(a), float32(b)))
end
function fma(a::Float16, b::Float16, c::Float16)
float16(fma(float32(a), float32(b), float32(c)))
end
for op in (:<,:<=,:isless)
@eval ($op)(a::Float16, b::Float16) = ($op)(float32(a), float32(b))
end
Expand Down
8 changes: 7 additions & 1 deletion base/mpfr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ export
import
Base: (*), +, -, /, <, <=, ==, >, >=, ^, besselj, besselj0, besselj1, bessely,
bessely0, bessely1, ceil, cmp, convert, copysign, deg2rad,
exp, exp2, exponent, factorial, floor, hypot, isinteger,
exp, exp2, exponent, factorial, floor, fma, hypot, isinteger,
isfinite, isinf, isnan, ldexp, log, log2, log10, max, min, mod, modf,
nextfloat, prevfloat, promote_rule, rad2deg, rem, round, show,
showcompact, sum, sqrt, string, print, trunc, precision, exp10, expm1,
Expand Down Expand Up @@ -306,6 +306,12 @@ function -(c::BigInt, x::BigFloat)
return z
end

function fma(x::BigFloat, y::BigFloat, z::BigFloat)
r = BigFloat()
ccall(("mpfr_fma",:libmpfr), Int32, (Ptr{BigFloat}, Ptr{BigFloat}, Ptr{BigFloat}, Ptr{BigFloat}, Int32), &r, &x, &y, &z, ROUNDING_MODE[end])
return r
end



# More efficient commutative operations
Expand Down
5 changes: 5 additions & 0 deletions base/promotion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ promote_to_super{T<:Number,S<:Number}(::Type{T}, ::Type{S}, ::Type) =
/(x::Number, y::Number) = /(promote(x,y)...)
^(x::Number, y::Number) = ^(promote(x,y)...)

fma(x::Number, y::Number, z::Number) = fma(promote(x,y,z)...)

(&)(x::Integer, y::Integer) = (&)(promote(x,y)...)
(|)(x::Integer, y::Integer) = (|)(promote(x,y)...)
($)(x::Integer, y::Integer) = ($)(promote(x,y)...)
Expand Down Expand Up @@ -191,6 +193,9 @@ no_op_err(name, T) = error(name," not defined for ",T)
/{T<:Number}(x::T, y::T) = no_op_err("/", T)
^{T<:Number}(x::T, y::T) = no_op_err("^", T)

fma{T<:Number}(x::T, y::T, z::T) = no_op_err("fma", T)
fma(x::Integer, y::Integer, z::Integer) = x*y+z

(&){T<:Integer}(x::T, y::T) = no_op_err("&", T)
(|){T<:Integer}(x::T, y::T) = no_op_err("|", T)
($){T<:Integer}(x::T, y::T) = no_op_err("\$", T)
Expand Down
2 changes: 2 additions & 0 deletions base/rational.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ end
/(x::Rational, y::Rational) = x//y
/(x::Rational, z::Complex ) = inv(z/x)

fma(x::Rational, y::Rational, z::Rational) = x*y+z

==(x::Rational, y::Rational) = (x.den == y.den) & (x.num == y.num)
==(x::Rational, y::Integer ) = (x.den == 1) & (x.num == y)
==(x::Integer , y::Rational) = y == x
Expand Down
6 changes: 6 additions & 0 deletions doc/stdlib/math.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ Mathematical Operators

Element-wise exponentiation operator.

.. function:: fma(x, y, z)

Computes ``x*y+z`` without rounding the intermediate result
``x*y``. On some systems this is significantly more expensive than
``x*y+z``.

.. function:: div(x, y)
÷(x, y)

Expand Down
19 changes: 16 additions & 3 deletions src/intrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ namespace JL_I {
neg_int, add_int, sub_int, mul_int,
sdiv_int, udiv_int, srem_int, urem_int, smod_int,
neg_float, add_float, sub_float, mul_float, div_float, rem_float,
fma_float,
// fast arithmetic
neg_float_fast, add_float_fast, sub_float_fast,
mul_float_fast, div_float_fast, rem_float_fast,
Expand Down Expand Up @@ -900,9 +901,13 @@ static Value *emit_intrinsic(intrinsic f, jl_value_t **args, size_t nargs,
if (nargs>1) {
y = auto_unbox(args[2], ctx);
}
Value *z = NULL;
if (nargs>2) {
z = auto_unbox(args[3], ctx);
}
Type *t = x->getType();
if (t == T_void || (y && y->getType() == T_void))
return t == T_void ? x : y;
if (t == T_void || (y && y->getType() == T_void) || (z && z->getType() == T_void))
return t == T_void ? x : y->getType() == T_void ? y : z;

Value *fy;
Value *den;
Expand Down Expand Up @@ -973,6 +978,14 @@ static Value *emit_intrinsic(intrinsic f, jl_value_t **args, size_t nargs,
HANDLE(mul_float_fast,2) return math_builder(ctx, true)().CreateFMul(FP(x), FP(y));
HANDLE(div_float_fast,2) return math_builder(ctx, true)().CreateFDiv(FP(x), FP(y));
HANDLE(rem_float_fast,2) return math_builder(ctx, true)().CreateFRem(FP(x), FP(y));
HANDLE(fma_float,3) {
assert(y->getType() == x->getType());
assert(z->getType() == y->getType());
return builder.CreateCall3
(Intrinsic::getDeclaration(jl_Module, Intrinsic::fma,
ArrayRef<Type*>(x->getType())),
FP(x), FP(y), FP(z));
}

HANDLE(checked_sadd,2)
HANDLE(checked_uadd,2)
Expand Down Expand Up @@ -1310,7 +1323,7 @@ extern "C" void jl_init_intrinsic_functions(void)
ADD_I(sdiv_int); ADD_I(udiv_int); ADD_I(srem_int); ADD_I(urem_int);
ADD_I(smod_int);
ADD_I(neg_float); ADD_I(add_float); ADD_I(sub_float); ADD_I(mul_float);
ADD_I(div_float); ADD_I(rem_float);
ADD_I(div_float); ADD_I(rem_float); ADD_I(fma_float);
ADD_I(neg_float_fast); ADD_I(add_float_fast); ADD_I(sub_float_fast);
ADD_I(mul_float_fast); ADD_I(div_float_fast); ADD_I(rem_float_fast);
ADD_I(eq_int); ADD_I(ne_int);
Expand Down
46 changes: 46 additions & 0 deletions test/numbers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,52 @@
@test minmax(NaN, 3.) == (3., 3.)
@test isequal(minmax(NaN, NaN), (NaN, NaN))

# fma
let x = int64(7)^7
@test fma(x-1, x-2, x-3) == (x-1) * (x-2) + (x-3)
@test (fma((x-1)//(x-2), (x-3)//(x-4), (x-5)//(x-6)) ==
(x-1)//(x-2) * (x-3)//(x-4) + (x-5)//(x-6))
end

let x = BigInt(7)^77
@test fma(x-1, x-2, x-3) == (x-1) * (x-2) + (x-3)
@test (fma((x-1)//(x-2), (x-3)//(x-4), (x-5)//(x-6)) ==
(x-1)//(x-2) * (x-3)//(x-4) + (x-5)//(x-6))
end

let eps = 1//BigInt(2)^30, one_eps = 1+eps,
eps64 = float64(eps), one_eps64 = float64(one_eps)
@test eps64 == float64(eps)
@test one_eps64 == float64(one_eps)
@test one_eps64 * one_eps64 - 1 != float64(one_eps * one_eps - 1)
@test fma(one_eps64, one_eps64, -1) == float64(one_eps * one_eps - 1)
end

let eps = 1//BigInt(2)^15, one_eps = 1+eps,
eps32 = float32(eps), one_eps32 = float32(one_eps)
@test eps32 == float32(eps)
@test one_eps32 == float32(one_eps)
@test one_eps32 * one_eps32 - 1 != float32(one_eps * one_eps - 1)
@test fma(one_eps32, one_eps32, -1) == float32(one_eps * one_eps - 1)
end

let eps = 1//BigInt(2)^7, one_eps = 1+eps,
eps16 = float16(float32(eps)), one_eps16 = float16(float32(one_eps))
@test eps16 == float16(float32(eps))
@test one_eps16 == float16(float32(one_eps))
@test one_eps16 * one_eps16 - 1 != float16(float32(one_eps * one_eps - 1))
@test (fma(one_eps16, one_eps16, -1) ==
float16(float32(one_eps * one_eps - 1)))
end

let eps = 1//BigInt(2)^200, one_eps = 1+eps,
eps256 = BigFloat(eps), one_eps256 = BigFloat(one_eps)
@test eps256 == BigFloat(eps)
@test one_eps256 == BigFloat(one_eps)
@test one_eps256 * one_eps256 - 1 != BigFloat(one_eps * one_eps - 1)
@test fma(one_eps256, one_eps256, -1) == BigFloat(one_eps * one_eps - 1)
end

# lexing typemin(Int64)
@test (-9223372036854775808)^1 == -9223372036854775808
@test [1 -1 -9223372036854775808] == [1 -1 typemin(Int64)]
Expand Down

0 comments on commit 06e2137

Please sign in to comment.