Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

math.hypot: fix incorrect over/underflow behavior #19472

Merged
merged 10 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lib/std/math.zig
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ pub const floatTrueMin = @import("math/float.zig").floatTrueMin;
pub const floatMin = @import("math/float.zig").floatMin;
pub const floatMax = @import("math/float.zig").floatMax;
pub const floatEps = @import("math/float.zig").floatEps;
pub const floatEpsAt = @import("math/float.zig").floatEpsAt;
pub const inf = @import("math/float.zig").inf;
pub const nan = @import("math/float.zig").nan;
pub const snan = @import("math/float.zig").snan;
Expand Down
13 changes: 13 additions & 0 deletions lib/std/math/float.zig
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,19 @@ pub inline fn floatEps(comptime T: type) T {
return reconstructFloat(T, -floatFractionalBits(T), mantissaOne(T));
}

/// Returns the local epsilon of floating point type T.
pub inline fn floatEpsAt(comptime T: type, x: T) T {
switch (@typeInfo(T)) {
.Float => |F| {
const U: type = @Type(.{ .Int = .{ .signedness = .unsigned, .bits = F.bits } });
const u: U = @bitCast(x);
const y: T = @bitCast(u ^ 1);
return @abs(x - y);
},
else => @compileError("floatEpsAt only supports floats"),
}
}

/// Returns the value inf for floating point type T.
pub inline fn inf(comptime T: type) T {
return reconstructFloat(T, floatExponentMax(T) + 1, mantissaOne(T));
Expand Down
243 changes: 99 additions & 144 deletions lib/std/math/hypot.zig
Original file line number Diff line number Diff line change
@@ -1,176 +1,131 @@
// Ported from musl, which is licensed under the MIT license:
// https://git.musl-libc.org/cgit/musl/tree/COPYRIGHT
//
// https://git.musl-libc.org/cgit/musl/tree/src/math/hypotf.c
// https://git.musl-libc.org/cgit/musl/tree/src/math/hypot.c

const std = @import("../std.zig");
const math = std.math;
const expect = std.testing.expect;
const maxInt = std.math.maxInt;
const isNan = math.isNan;
const isInf = math.isInf;
const inf = math.inf;
const nan = math.nan;
const floatEpsAt = math.floatEpsAt;
const floatEps = math.floatEps;
const floatMin = math.floatMin;
const floatMax = math.floatMax;

/// Returns sqrt(x * x + y * y), avoiding unnecessary overflow and underflow.
///
/// Special Cases:
///
/// | x | y | hypot |
/// |-------|-------|-------|
/// | +inf | num | +inf |
/// | num | +-inf | +inf |
/// | nan | any | nan |
/// | any | nan | nan |
/// | +-inf | any | +inf |
/// | any | +-inf | +inf |
/// | nan | fin | nan |
/// | fin | nan | nan |
pub fn hypot(x: anytype, y: anytype) @TypeOf(x, y) {
const T = @TypeOf(x, y);
return switch (T) {
f32 => hypot32(x, y),
f64 => hypot64(x, y),
switch (@typeInfo(T)) {
.Float => {},
.ComptimeFloat => return @sqrt(x * x + y * y),
else => @compileError("hypot not implemented for " ++ @typeName(T)),
};
}

fn hypot32(x: f32, y: f32) f32 {
var ux = @as(u32, @bitCast(x));
var uy = @as(u32, @bitCast(y));

ux &= maxInt(u32) >> 1;
uy &= maxInt(u32) >> 1;
if (ux < uy) {
const tmp = ux;
ux = uy;
uy = tmp;
}

var xx = @as(f32, @bitCast(ux));
var yy = @as(f32, @bitCast(uy));
if (uy == 0xFF << 23) {
return yy;
}
if (ux >= 0xFF << 23 or uy == 0 or ux - uy >= (25 << 23)) {
return xx + yy;
}

var z: f32 = 1.0;
if (ux >= (0x7F + 60) << 23) {
z = 0x1.0p90;
xx *= 0x1.0p-90;
yy *= 0x1.0p-90;
} else if (uy < (0x7F - 60) << 23) {
z = 0x1.0p-90;
xx *= 0x1.0p-90;
yy *= 0x1.0p-90;
const lower = @sqrt(floatMin(T));
const upper = @sqrt(floatMax(T) / 2);
const incre = @sqrt(floatEps(T) / 2);
const scale = floatEpsAt(T, incre);
const hypfn = if (emulateFma(T)) hypotUnfused else hypotFused;
var major: T = x;
var minor: T = y;
if (isInf(major) or isInf(minor)) return inf(T);
if (isNan(major) or isNan(minor)) return nan(T);
if (T == f16) return @floatCast(@sqrt(@mulAdd(f32, x, x, @as(f32, y) * y)));
if (T == f32) return @floatCast(@sqrt(@mulAdd(f64, x, x, @as(f64, y) * y)));
major = @abs(major);
minor = @abs(minor);
if (minor > major) {
const tempo = major;
major = minor;
minor = tempo;
}

return z * @sqrt(@as(f32, @floatCast(@as(f64, x) * x + @as(f64, y) * y)));
if (major * incre >= minor) return major;
if (major > upper) return hypfn(T, major * scale, minor * scale) / scale;
if (minor < lower) return hypfn(T, major / scale, minor / scale) * scale;
return hypfn(T, major, minor);
}

fn sq(hi: *f64, lo: *f64, x: f64) void {
const split: f64 = 0x1.0p27 + 1.0;
const xc = x * split;
const xh = x - xc + xc;
const xl = x - xh;
hi.* = x * x;
lo.* = xh * xh - hi.* + 2 * xh * xl + xl * xl;
inline fn emulateFma(comptime T: type) bool {
// If @mulAdd lowers to the software implementation,
// hypotUnfused should be used in place of hypotFused.
// This takes an educated guess, but ideally we should
// properly detect at comptime when that fallback will
// occur.
return (T == f128 or T == f80);
}

fn hypot64(x: f64, y: f64) f64 {
var ux = @as(u64, @bitCast(x));
var uy = @as(u64, @bitCast(y));

ux &= maxInt(u64) >> 1;
uy &= maxInt(u64) >> 1;
if (ux < uy) {
const tmp = ux;
ux = uy;
uy = tmp;
}

const ex = ux >> 52;
const ey = uy >> 52;
var xx = @as(f64, @bitCast(ux));
var yy = @as(f64, @bitCast(uy));

// hypot(inf, nan) == inf
if (ey == 0x7FF) {
return yy;
}
if (ex == 0x7FF or uy == 0) {
return xx;
}

// hypot(x, y) ~= x + y * y / x / 2 with inexact for small y/x
if (ex - ey > 64) {
return xx + yy;
}
inline fn hypotFused(comptime F: type, x: F, y: F) F {
const r = @sqrt(@mulAdd(F, x, x, y * y));
const rr = r * r;
const xx = x * x;
const z = @mulAdd(F, -y, y, rr - xx) + @mulAdd(F, r, r, -rr) - @mulAdd(F, x, x, -xx);
return r - z / (2 * r);
}

var z: f64 = 1;
if (ex > 0x3FF + 510) {
z = 0x1.0p700;
xx *= 0x1.0p-700;
yy *= 0x1.0p-700;
} else if (ey < 0x3FF - 450) {
z = 0x1.0p-700;
xx *= 0x1.0p700;
yy *= 0x1.0p700;
inline fn hypotUnfused(comptime F: type, x: F, y: F) F {
expikr marked this conversation as resolved.
Show resolved Hide resolved
const r = @sqrt(x * x + y * y);
if (r <= 2 * y) { // 30deg or steeper
const dx = r - y;
const z = x * (2 * dx - x) + (dx - 2 * (x - y)) * dx;
return r - z / (2 * r);
} else { // shallower than 30 deg
const dy = r - x;
const z = 2 * dy * (x - 2 * y) + (4 * dy - y) * y + dy * dy;
return r - z / (2 * r);
}

var hx: f64 = undefined;
var lx: f64 = undefined;
var hy: f64 = undefined;
var ly: f64 = undefined;

sq(&hx, &lx, x);
sq(&hy, &ly, y);

return z * @sqrt(ly + lx + hy + hx);
}

const hypot_test_cases = .{
.{ 0.0, -1.2, 1.2 },
.{ 0.2, -0.34, 0.3944616584663203993612799816649560759946493601889826495362 },
.{ 0.8923, 2.636890, 2.7837722899152509525110650481670176852603253522923737962880 },
.{ 1.5, 5.25, 5.4600824169603887033229768686452745953332522619323580787836 },
.{ 37.45, 159.835, 164.16372840856167640478217141034363907565754072954443805164 },
.{ 89.123, 382.028905, 392.28687638576315875933966414927490685367196874260165618371 },
.{ 123123.234375, 529428.707813, 543556.88524707706887251269205923830745438413088753096759371 },
};

test hypot {
const x32: f32 = 0.0;
const y32: f32 = -1.2;
const x64: f64 = 0.0;
const y64: f64 = -1.2;
try expect(hypot(x32, y32) == hypot32(0.0, -1.2));
try expect(hypot(x64, y64) == hypot64(0.0, -1.2));
try expect(hypot(0.3, 0.4) == 0.5);
}

test hypot32 {
const epsilon = 0.000001;

try expect(math.approxEqAbs(f32, hypot32(0.0, -1.2), 1.2, epsilon));
try expect(math.approxEqAbs(f32, hypot32(0.2, -0.34), 0.394462, epsilon));
try expect(math.approxEqAbs(f32, hypot32(0.8923, 2.636890), 2.783772, epsilon));
try expect(math.approxEqAbs(f32, hypot32(1.5, 5.25), 5.460083, epsilon));
try expect(math.approxEqAbs(f32, hypot32(37.45, 159.835), 164.163742, epsilon));
try expect(math.approxEqAbs(f32, hypot32(89.123, 382.028905), 392.286865, epsilon));
try expect(math.approxEqAbs(f32, hypot32(123123.234375, 529428.707813), 543556.875, epsilon));
test "hypot.correct" {
inline for (.{ f16, f32, f64, f128 }) |T| {
inline for (hypot_test_cases) |v| {
const a: T, const b: T, const c: T = v;
try expect(math.approxEqRel(T, hypot(a, b), c, @sqrt(floatEps(T))));
}
}
}

test hypot64 {
const epsilon = 0.000001;

try expect(math.approxEqAbs(f64, hypot64(0.0, -1.2), 1.2, epsilon));
try expect(math.approxEqAbs(f64, hypot64(0.2, -0.34), 0.394462, epsilon));
try expect(math.approxEqAbs(f64, hypot64(0.8923, 2.636890), 2.783772, epsilon));
try expect(math.approxEqAbs(f64, hypot64(1.5, 5.25), 5.460082, epsilon));
try expect(math.approxEqAbs(f64, hypot64(37.45, 159.835), 164.163728, epsilon));
try expect(math.approxEqAbs(f64, hypot64(89.123, 382.028905), 392.286876, epsilon));
try expect(math.approxEqAbs(f64, hypot64(123123.234375, 529428.707813), 543556.885247, epsilon));
test "hypot.precise" {
inline for (.{ f16, f32, f64 }) |T| { // f128 seems to be 5 ulp
inline for (hypot_test_cases) |v| {
const a: T, const b: T, const c: T = v;
try expect(math.approxEqRel(T, hypot(a, b), c, floatEps(T)));
}
}
}

test "hypot32.special" {
try expect(math.isPositiveInf(hypot32(math.inf(f32), 0.0)));
try expect(math.isPositiveInf(hypot32(-math.inf(f32), 0.0)));
try expect(math.isPositiveInf(hypot32(0.0, math.inf(f32))));
try expect(math.isPositiveInf(hypot32(0.0, -math.inf(f32))));
try expect(math.isNan(hypot32(math.nan(f32), 0.0)));
try expect(math.isNan(hypot32(0.0, math.nan(f32))));
}
test "hypot.special" {
inline for (.{ f16, f32, f64, f128 }) |T| {
try expect(math.isNan(hypot(nan(T), 0.0)));
try expect(math.isNan(hypot(0.0, nan(T))));

try expect(math.isPositiveInf(hypot(inf(T), 0.0)));
try expect(math.isPositiveInf(hypot(0.0, inf(T))));
try expect(math.isPositiveInf(hypot(inf(T), nan(T))));
try expect(math.isPositiveInf(hypot(nan(T), inf(T))));

test "hypot64.special" {
try expect(math.isPositiveInf(hypot64(math.inf(f64), 0.0)));
try expect(math.isPositiveInf(hypot64(-math.inf(f64), 0.0)));
try expect(math.isPositiveInf(hypot64(0.0, math.inf(f64))));
try expect(math.isPositiveInf(hypot64(0.0, -math.inf(f64))));
try expect(math.isNan(hypot64(math.nan(f64), 0.0)));
try expect(math.isNan(hypot64(0.0, math.nan(f64))));
try expect(math.isPositiveInf(hypot(-inf(T), 0.0)));
try expect(math.isPositiveInf(hypot(0.0, -inf(T))));
try expect(math.isPositiveInf(hypot(-inf(T), nan(T))));
try expect(math.isPositiveInf(hypot(nan(T), -inf(T))));
}
}