Skip to content

Commit

Permalink
fix(stdlib): Fix NaN comparisons (#1543)
Browse files Browse the repository at this point in the history
feat(runtime): Optimize simple number comparison
  • Loading branch information
ospencer committed Dec 10, 2022
1 parent eeb2eaa commit f7ceae7
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 55 deletions.
12 changes: 12 additions & 0 deletions compiler/test/stdlib/pervasives.test.gr
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,18 @@ assert compare(0.0 / 0.0, -1 / 0.0) < 0
assert compare(0.0 / 0.0, 987654321.) < 0
assert compare(0.0 / 0.0, 987654321) < 0
assert compare(0.0 / 0.0, 0) < 0
assert !(0.0 / 0.0 < 0.0 / 0.0)
assert !(0.0 / 0.0 < 10)
assert !(0.0 / 0.0 < 10.)
assert !(0.0 / 0.0 <= 0.0 / 0.0)
assert !(0.0 / 0.0 <= 10)
assert !(0.0 / 0.0 <= 10.)
assert !(0.0 / 0.0 > 0.0 / 0.0)
assert !(0.0 / 0.0 > 10)
assert !(0.0 / 0.0 > 10.)
assert !(0.0 / 0.0 >= 0.0 / 0.0)
assert !(0.0 / 0.0 >= 10)
assert !(0.0 / 0.0 >= 10.)
// Booleans
assert compare(false, true) < 0
assert compare(true, false) > 0
Expand Down
24 changes: 3 additions & 21 deletions stdlib/number.gr
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import {
isInteger,
isRational,
isBoxedNumber,
isNaN,
scalbn,
} from "runtime/numbers"
import Atoi from "runtime/atoi/parse"
Expand Down Expand Up @@ -814,8 +815,7 @@ export let isFinite = (x: Number) => {
}

/**
* Checks if a number contains the NaN value (Not A Number).
* Only boxed floating point numbers can contain NaN.
* Checks if a number is the float NaN value (Not A Number).
*
* @param x: The number to check
* @returns `true` if the value is NaN, otherwise `false`
Expand All @@ -825,25 +825,7 @@ export let isFinite = (x: Number) => {
@unsafe
export let isNaN = (x: Number) => {
let asPtr = WasmI32.fromGrain(x)
if (isBoxedNumber(asPtr)) {
// Boxed numbers can have multiple subtypes, of which float32 and float64 can be NaN.
let tag = WasmI32.load(asPtr, 4n)
if (WasmI32.eq(tag, Tags._GRAIN_FLOAT64_BOXED_NUM_TAG)) {
// uses the fact that NaN is the only number not equal to itself
let wf64 = WasmF64.load(asPtr, 8n)
WasmF64.ne(wf64, wf64)
} else if (WasmI32.eq(tag, Tags._GRAIN_FLOAT32_BOXED_NUM_TAG)) {
let wf32 = WasmF32.load(asPtr, 8n)
WasmF32.ne(wf32, wf32)
} else {
// Neither rational numbers nor boxed integers can be infinite or NaN.
// Grain doesn't allow creating a rational with denominator of zero either.
false
}
} else {
// Simple numbers are integers and cannot be NaN.
false
}
isNaN(asPtr)
}

/**
Expand Down
3 changes: 1 addition & 2 deletions stdlib/number.md
Original file line number Diff line number Diff line change
Expand Up @@ -656,8 +656,7 @@ No other changes yet.
isNaN : Number -> Bool
```

Checks if a number contains the NaN value (Not A Number).
Only boxed floating point numbers can contain NaN.
Checks if a number is the float NaN value (Not A Number).

Parameters:

Expand Down
2 changes: 1 addition & 1 deletion stdlib/runtime/compare.gr
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ compareHelp = (x, y) => {
}
} else if (isNumber(x)) {
// Numbers have special comparison rules, e.g. NaN == NaN
tagSimpleNumber(numberCompare(x, y, true))
tagSimpleNumber(numberCompare(x, y))
} else {
// Handle all other heap allocated things
// Can short circuit if pointers are the same
Expand Down
94 changes: 64 additions & 30 deletions stdlib/runtime/numbers.gr
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,29 @@ export let isRational = x => {
}
}

@unsafe
export let isNaN = x => {
if (isBoxedNumber(x)) {
// Boxed numbers can have multiple subtypes, of which float32 and float64 can be NaN.
let tag = WasmI32.load(x, 4n)
if (WasmI32.eq(tag, Tags._GRAIN_FLOAT64_BOXED_NUM_TAG)) {
// uses the fact that NaN is the only number not equal to itself
let wf64 = WasmF64.load(x, 8n)
WasmF64.ne(wf64, wf64)
} else if (WasmI32.eq(tag, Tags._GRAIN_FLOAT32_BOXED_NUM_TAG)) {
let wf32 = WasmF32.load(x, 8n)
WasmF32.ne(wf32, wf32)
} else {
// Neither rational numbers nor boxed integers can be infinite or NaN.
// Grain doesn't allow creating a rational with denominator of zero either.
false
}
} else {
// Simple numbers are integers and cannot be NaN.
false
}
}

@unsafe
let isBigInt = x => {
if (isBoxedNumber(x)) {
Expand Down Expand Up @@ -1777,8 +1800,11 @@ let cmpBigInt = (x: WasmI32, y: WasmI32) => {
}
}

// cmpFloat applies a total ordering relation:
// unlike regular float logic, NaN is considered equal to itself and
// smaller than any other number
@unsafe
let cmpFloat = (x: WasmI32, y: WasmI32, is64: Bool, totalOrdering: Bool) => {
let cmpFloat = (x: WasmI32, y: WasmI32, is64: Bool) => {
let xf = if (is64) {
boxedFloat64Number(x)
} else {
Expand All @@ -1787,13 +1813,13 @@ let cmpFloat = (x: WasmI32, y: WasmI32, is64: Bool, totalOrdering: Bool) => {
if (isSimpleNumber(y)) {
let yf = WasmF64.convertI32S(untagSimple(y))
// special NaN cases
if (totalOrdering && WasmF64.ne(xf, xf)) {
if (WasmF64.ne(xf, xf)) {
if (WasmF64.ne(yf, yf)) {
0n
} else {
-1n
}
} else if (totalOrdering && WasmF64.ne(yf, yf)) {
} else if (WasmF64.ne(yf, yf)) {
if (WasmF64.ne(xf, xf)) {
0n
} else {
Expand Down Expand Up @@ -1834,13 +1860,13 @@ let cmpFloat = (x: WasmI32, y: WasmI32, is64: Bool, totalOrdering: Bool) => {
},
}
// special NaN cases
if (totalOrdering && WasmF64.ne(xf, xf)) {
if (WasmF64.ne(xf, xf)) {
if (WasmF64.ne(yf, yf)) {
0n
} else {
-1n
}
} else if (totalOrdering && WasmF64.ne(yf, yf)) {
} else if (WasmF64.ne(yf, yf)) {
if (WasmF64.ne(xf, xf)) {
0n
} else {
Expand All @@ -1854,7 +1880,7 @@ let cmpFloat = (x: WasmI32, y: WasmI32, is64: Bool, totalOrdering: Bool) => {
}

@unsafe
let cmpSmallInt = (x: WasmI32, y: WasmI32, is64: Bool, totalOrdering: Bool) => {
let cmpSmallInt = (x: WasmI32, y: WasmI32, is64: Bool) => {
let xi = if (is64) {
boxedInt64Number(x)
} else {
Expand Down Expand Up @@ -1890,10 +1916,10 @@ let cmpSmallInt = (x: WasmI32, y: WasmI32, is64: Bool, totalOrdering: Bool) => {
) -1n else 1n
},
t when WasmI32.eq(t, Tags._GRAIN_FLOAT32_BOXED_NUM_TAG) => {
WasmI32.sub(0n, cmpFloat(y, x, false, totalOrdering))
WasmI32.sub(0n, cmpFloat(y, x, false))
},
t when WasmI32.eq(t, Tags._GRAIN_FLOAT64_BOXED_NUM_TAG) => {
WasmI32.sub(0n, cmpFloat(y, x, true, totalOrdering))
WasmI32.sub(0n, cmpFloat(y, x, true))
},
_ => {
throw UnknownNumberTag
Expand All @@ -1903,7 +1929,7 @@ let cmpSmallInt = (x: WasmI32, y: WasmI32, is64: Bool, totalOrdering: Bool) => {
}

@unsafe
let cmpRational = (x: WasmI32, y: WasmI32, totalOrdering: Bool) => {
let cmpRational = (x: WasmI32, y: WasmI32) => {
if (isSimpleNumber(y)) {
let xf = WasmF64.div(
BI.toFloat64(boxedRationalNumerator(x)),
Expand All @@ -1915,10 +1941,10 @@ let cmpRational = (x: WasmI32, y: WasmI32, totalOrdering: Bool) => {
let yBoxedNumberTag = boxedNumberTag(y)
match (yBoxedNumberTag) {
t when WasmI32.eq(t, Tags._GRAIN_INT32_BOXED_NUM_TAG) => {
WasmI32.sub(0n, cmpSmallInt(y, x, false, totalOrdering))
WasmI32.sub(0n, cmpSmallInt(y, x, false))
},
t when WasmI32.eq(t, Tags._GRAIN_INT64_BOXED_NUM_TAG) => {
WasmI32.sub(0n, cmpSmallInt(y, x, true, totalOrdering))
WasmI32.sub(0n, cmpSmallInt(y, x, true))
},
t when WasmI32.eq(t, Tags._GRAIN_BIGINT_BOXED_NUM_TAG) => {
WasmI32.sub(0n, cmpBigInt(y, x))
Expand Down Expand Up @@ -1951,10 +1977,10 @@ let cmpRational = (x: WasmI32, y: WasmI32, totalOrdering: Bool) => {
}
},
t when WasmI32.eq(t, Tags._GRAIN_FLOAT32_BOXED_NUM_TAG) => {
WasmI32.sub(0n, cmpFloat(y, x, false, totalOrdering))
WasmI32.sub(0n, cmpFloat(y, x, false))
},
t when WasmI32.eq(t, Tags._GRAIN_FLOAT64_BOXED_NUM_TAG) => {
WasmI32.sub(0n, cmpFloat(y, x, true, totalOrdering))
WasmI32.sub(0n, cmpFloat(y, x, true))
},
_ => {
throw UnknownNumberTag
Expand All @@ -1964,30 +1990,31 @@ let cmpRational = (x: WasmI32, y: WasmI32, totalOrdering: Bool) => {
}

@unsafe
export let cmp = (x: WasmI32, y: WasmI32, totalOrdering: Bool) => {
export let cmp = (x: WasmI32, y: WasmI32) => {
if (isSimpleNumber(x)) {
if (isSimpleNumber(y)) {
if (WasmI32.ltS(x, y)) -1n else if (WasmI32.gtS(x, y)) 1n else 0n
// fast comparison path for simple numbers
WasmI32.sub(x, y)
} else {
let yBoxedNumberTag = boxedNumberTag(y)
match (yBoxedNumberTag) {
t when WasmI32.eq(t, Tags._GRAIN_INT32_BOXED_NUM_TAG) => {
WasmI32.sub(0n, cmpSmallInt(y, x, false, totalOrdering))
WasmI32.sub(0n, cmpSmallInt(y, x, false))
},
t when WasmI32.eq(t, Tags._GRAIN_INT64_BOXED_NUM_TAG) => {
WasmI32.sub(0n, cmpSmallInt(y, x, true, totalOrdering))
WasmI32.sub(0n, cmpSmallInt(y, x, true))
},
t when WasmI32.eq(t, Tags._GRAIN_BIGINT_BOXED_NUM_TAG) => {
WasmI32.sub(0n, cmpBigInt(y, x))
},
t when WasmI32.eq(t, Tags._GRAIN_RATIONAL_BOXED_NUM_TAG) => {
WasmI32.sub(0n, cmpRational(y, x, totalOrdering))
WasmI32.sub(0n, cmpRational(y, x))
},
t when WasmI32.eq(t, Tags._GRAIN_FLOAT32_BOXED_NUM_TAG) => {
WasmI32.sub(0n, cmpFloat(y, x, false, totalOrdering))
WasmI32.sub(0n, cmpFloat(y, x, false))
},
t when WasmI32.eq(t, Tags._GRAIN_FLOAT64_BOXED_NUM_TAG) => {
WasmI32.sub(0n, cmpFloat(y, x, true, totalOrdering))
WasmI32.sub(0n, cmpFloat(y, x, true))
},
_ => {
throw UnknownNumberTag
Expand All @@ -1998,22 +2025,22 @@ export let cmp = (x: WasmI32, y: WasmI32, totalOrdering: Bool) => {
let xBoxedNumberTag = boxedNumberTag(x)
match (xBoxedNumberTag) {
t when WasmI32.eq(t, Tags._GRAIN_INT32_BOXED_NUM_TAG) => {
cmpSmallInt(x, y, false, totalOrdering)
cmpSmallInt(x, y, false)
},
t when WasmI32.eq(t, Tags._GRAIN_INT64_BOXED_NUM_TAG) => {
cmpSmallInt(x, y, true, totalOrdering)
cmpSmallInt(x, y, true)
},
t when WasmI32.eq(t, Tags._GRAIN_BIGINT_BOXED_NUM_TAG) => {
cmpBigInt(x, y)
},
t when WasmI32.eq(t, Tags._GRAIN_RATIONAL_BOXED_NUM_TAG) => {
cmpRational(x, y, totalOrdering)
cmpRational(x, y)
},
t when WasmI32.eq(t, Tags._GRAIN_FLOAT32_BOXED_NUM_TAG) => {
cmpFloat(x, y, false, totalOrdering)
cmpFloat(x, y, false)
},
t when WasmI32.eq(t, Tags._GRAIN_FLOAT64_BOXED_NUM_TAG) => {
cmpFloat(x, y, true, totalOrdering)
cmpFloat(x, y, true)
},
_ => {
throw UnknownNumberTag
Expand All @@ -2022,39 +2049,46 @@ export let cmp = (x: WasmI32, y: WasmI32, totalOrdering: Bool) => {
}
}

// In the comparison functions below, NaN is neither greater than, less than,
// or equal to any other number (including NaN), so any comparison involving
// NaN is always false. The only exception to this rule is `compare`, which
// applies a total ordering relation to allow numbers to be sortable (with
// NaN being considered equal to itself and less than all other numbers in
// this case).

@unsafe
export let (<) = (x: Number, y: Number) => {
let x = WasmI32.fromGrain(x)
let y = WasmI32.fromGrain(y)
WasmI32.ltS(cmp(x, y, false), 0n)
!isNaN(x) && !isNaN(y) && WasmI32.ltS(cmp(x, y), 0n)
}

@unsafe
export let (>) = (x: Number, y: Number) => {
let x = WasmI32.fromGrain(x)
let y = WasmI32.fromGrain(y)
WasmI32.gtS(cmp(x, y, false), 0n)
!isNaN(x) && !isNaN(y) && WasmI32.gtS(cmp(x, y), 0n)
}

@unsafe
export let (<=) = (x: Number, y: Number) => {
let x = WasmI32.fromGrain(x)
let y = WasmI32.fromGrain(y)
WasmI32.leS(cmp(x, y, false), 0n)
!isNaN(x) && !isNaN(y) && WasmI32.leS(cmp(x, y), 0n)
}

@unsafe
export let (>=) = (x: Number, y: Number) => {
let x = WasmI32.fromGrain(x)
let y = WasmI32.fromGrain(y)
WasmI32.geS(cmp(x, y, false), 0n)
!isNaN(x) && !isNaN(y) && WasmI32.geS(cmp(x, y), 0n)
}

@unsafe
export let compare = (x: Number, y: Number) => {
let x = WasmI32.fromGrain(x)
let y = WasmI32.fromGrain(y)
WasmI32.toGrain(tagSimple(cmp(x, y, true))): Number
WasmI32.toGrain(tagSimple(cmp(x, y))): Number
}

/*
Expand Down
8 changes: 7 additions & 1 deletion stdlib/runtime/numbers.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ isInteger : WasmI32 -> Bool
isRational : WasmI32 -> Bool
```

### Numbers.**isNaN**

```grain
isNaN : WasmI32 -> Bool
```

### Numbers.**isNumber**

```grain
Expand Down Expand Up @@ -109,7 +115,7 @@ numberEqual : (WasmI32, WasmI32) -> Bool
### Numbers.**cmp**

```grain
cmp : (WasmI32, WasmI32, Bool) -> WasmI32
cmp : (WasmI32, WasmI32) -> WasmI32
```

### Numbers.**(<)**
Expand Down

0 comments on commit f7ceae7

Please sign in to comment.