Skip to content

Commit

Permalink
inference: fix #41450, replace constant-folded muladd with fma
Browse files Browse the repository at this point in the history
So that we can get more accurate results when performance doesn't matter
  • Loading branch information
aviatesk committed Jul 13, 2021
1 parent e422467 commit 65e3e7b
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 6 deletions.
4 changes: 3 additions & 1 deletion base/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ using .Sort
something(x::Nothing, y...) = something(y...)
something(x::Any, y...) = x

const ARCH = ccall(:jl_get_ARCH, Any, ())
include("build_h.jl")

############
# compiler #
############
Expand Down Expand Up @@ -142,4 +145,3 @@ Core.eval(Core, :(_parse = Compiler.fl_parse))

end # baremodule Compiler
))

27 changes: 27 additions & 0 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1514,6 +1514,24 @@ function builtin_nothrow(@nospecialize(f), argtypes::Array{Any, 1}, @nospecializ
return _builtin_nothrow(f, argtypes, rt)
end

# NOTE sync the definitions below with those defined in floatfuncs.jl
fma_libm(x::Float32, y::Float32, z::Float32) =
ccall(("fmaf", libm_name), Float32, (Float32,Float32,Float32), x, y, z)
fma_libm(x::Float64, y::Float64, z::Float64) =
ccall(("fma", libm_name), Float64, (Float64,Float64,Float64), x, y, z)
fma_llvm(x::Float32, y::Float32, z::Float32) = fma_float(x, y, z)
fma_llvm(x::Float64, y::Float64, z::Float64) = fma_float(x, y, z)
if ARCH !== :i686 &&
fma_llvm(1.0000305f0, 1.0000305f0, -1.0f0) == 6.103609f-5 &&
fma_llvm(1.0000000009313226, 1.0000000009313226, -1.0) == 1.8626451500983188e-9 &&
add_float(0.1, 0.2) == 0.30000000000000004
fma(x::Float32, y::Float32, z::Float32) = fma_llvm(x,y,z)
fma(x::Float64, y::Float64, z::Float64) = fma_llvm(x,y,z)
else
fma(x::Float32, y::Float32, z::Float32) = fma_libm(x,y,z)
fma(x::Float64, y::Float64, z::Float64) = fma_libm(x,y,z)
end

function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtypes::Array{Any,1},
sv::Union{InferenceState,Nothing})
if f === tuple
Expand All @@ -1522,6 +1540,15 @@ function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtyp
if isa(f, IntrinsicFunction)
if is_pure_intrinsic_infer(f) && _all(@nospecialize(a) -> isa(a, Const), argtypes)
argvals = anymap(a::Const -> a.val, argtypes)
# https://github.com/JuliaLang/julia/issues/41450
# enforce no rounding for better accuracy
if f === muladd_float && length(argvals) == 3
if _all(@nospecialize(a) -> isa(a, Float32), argvals)
return Const(fma(argvals[1]::Float32, argvals[2]::Float32, argvals[3]::Float32))
elseif _all(@nospecialize(a) -> isa(a, Float64), argvals)
return Const(fma(argvals[1]::Float64, argvals[2]::Float64, argvals[3]::Float64))
end
end
try
return Const(f(argvals...))
catch
Expand Down
8 changes: 5 additions & 3 deletions base/floatfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -340,9 +340,11 @@ fma_llvm(x::Float64, y::Float64, z::Float64) = fma_float(x, y, z)
# 1.0000000009313226 = 1 + 1/2^30
# If fma_llvm() clobbers the rounding mode, the result of 0.1 + 0.2 will be 0.3
# instead of the properly-rounded 0.30000000000000004; check after calling fma
if (Sys.ARCH !== :i686 && fma_llvm(1.0000305f0, 1.0000305f0, -1.0f0) == 6.103609f-5 &&
(fma_llvm(1.0000000009313226, 1.0000000009313226, -1.0) ==
1.8626451500983188e-9) && 0.1 + 0.2 == 0.30000000000000004)
# NOTE this system level check is also used in compiler/tfuncs.jl, make sure to sync them
if Sys.ARCH !== :i686 &&
fma_llvm(1.0000305f0, 1.0000305f0, -1.0f0) == 6.103609f-5 &&
fma_llvm(1.0000000009313226, 1.0000000009313226, -1.0) == 1.8626451500983188e-9 &&
0.1 + 0.2 == 0.30000000000000004
fma(x::Float32, y::Float32, z::Float32) = fma_llvm(x,y,z)
fma(x::Float64, y::Float64, z::Float64) = fma_llvm(x,y,z)
else
Expand Down
6 changes: 6 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3395,3 +3395,9 @@ end
x.x
end) == Any[Int]
end

# https://github.com/JuliaLang/julia/issues/41450
@test (@eval Module() begin
foo(x=1.0) = muladd(1 + eps(x), 1 - eps(x), -1)
foo() == foo(1.0)
end)
3 changes: 1 addition & 2 deletions test/mpfr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -606,8 +606,7 @@ end
@test log(x) == log(42)
@test isinf(log(BigFloat(0)))
@test_throws DomainError log(BigFloat(-1))
# issue #41450
@test_skip log2(x) == log2(42)
@test log2(x) == log2(42)
@test isinf(log2(BigFloat(0)))
@test_throws DomainError log2(BigFloat(-1))
@test log10(x) == log10(42)
Expand Down

0 comments on commit 65e3e7b

Please sign in to comment.