Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce IEEE P3109 dtypes #122

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: end-of-file-fixer
- id: trailing-whitespace
- id: debug-statements
- repo: https://github.com/google/pyink
rev: 23.3.1
rev: 23.10.0
hooks:
- id: pyink
language_version: python3.9
Expand Down
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
* `float8_e4m3fnuz`
* `float8_e5m2`
* `float8_e5m2fnuz`
* `float8_p3109_p<p>`
- `int4` and `uint4`: low precision integer types.

See below for specifications of these number formats.
Expand Down Expand Up @@ -107,6 +108,20 @@ This type has the following characteristics:
* NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s - `0b10000000`
* denormals when exponent is 0

### float8_p3109_p<p>

These types represent the types under discussion in IEEE working group P3109,
"Arithmetic Formats for Machine Learning", parameterized by precision $p$.

These types have the following characteristics:
* Precision $p$: $2 < p < 6$
* Exponent bits, E: $8-p$
* Exponent bias: $2^{E-1}$
* Infinities: +Inf, -Inf
* No negative zero
* Single NaN in the -0 position: `0b10000000` == `0x80`
* Denormals when exponent is 0

## `int4` and `uint4`

4-bit integer types, where each element is represented unpacked (i.e., padded up
Expand Down
33 changes: 21 additions & 12 deletions ml_dtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = '0.3.1' # Keep in sync with pyproject.toml:version
__version__ = "0.3.1" # Keep in sync with pyproject.toml:version
__all__ = [
'__version__',
'bfloat16',
'finfo',
'float8_e4m3b11fnuz',
'float8_e4m3fn',
'float8_e4m3fnuz',
'float8_e5m2',
'float8_e5m2fnuz',
'iinfo',
'int4',
'uint4',
"__version__",
"bfloat16",
"finfo",
"float8_e4m3b11fnuz",
"float8_e4m3fn",
"float8_e4m3fnuz",
"float8_e5m2",
"float8_e5m2fnuz",
"float8_p3109_p3",
"float8_p3109_p4",
"float8_p3109_p5",
"iinfo",
"int4",
"uint4",
]

from typing import Type
Expand All @@ -37,6 +40,9 @@
from ml_dtypes._ml_dtypes_ext import float8_e4m3fnuz
from ml_dtypes._ml_dtypes_ext import float8_e5m2
from ml_dtypes._ml_dtypes_ext import float8_e5m2fnuz
from ml_dtypes._ml_dtypes_ext import float8_p3109_p3
from ml_dtypes._ml_dtypes_ext import float8_p3109_p4
from ml_dtypes._ml_dtypes_ext import float8_p3109_p5
from ml_dtypes._ml_dtypes_ext import int4
from ml_dtypes._ml_dtypes_ext import uint4
import numpy as np
Expand All @@ -47,6 +53,9 @@
float8_e4m3fnuz: Type[np.generic]
float8_e5m2: Type[np.generic]
float8_e5m2fnuz: Type[np.generic]
float8_p3109_p3: Type[np.generic]
float8_p3109_p4: Type[np.generic]
float8_p3109_p5: Type[np.generic]
int4: Type[np.generic]
uint4: Type[np.generic]

Expand Down
148 changes: 148 additions & 0 deletions ml_dtypes/_finfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
from ml_dtypes._ml_dtypes_ext import float8_e4m3fnuz
from ml_dtypes._ml_dtypes_ext import float8_e5m2
from ml_dtypes._ml_dtypes_ext import float8_e5m2fnuz
from ml_dtypes._ml_dtypes_ext import float8_p3109_p3
from ml_dtypes._ml_dtypes_ext import float8_p3109_p4
from ml_dtypes._ml_dtypes_ext import float8_p3109_p5

import numpy as np

_bfloat16_dtype = np.dtype(bfloat16)
Expand All @@ -30,6 +34,9 @@
_float8_e4m3fnuz_dtype = np.dtype(float8_e4m3fnuz)
_float8_e5m2_dtype = np.dtype(float8_e5m2)
_float8_e5m2fnuz_dtype = np.dtype(float8_e5m2fnuz)
_float8_p3109_p3_dtype = np.dtype(float8_p3109_p3)
_float8_p3109_p4_dtype = np.dtype(float8_p3109_p4)
_float8_p3109_p5_dtype = np.dtype(float8_p3109_p5)


class _Bfloat16MachArLike:
Expand Down Expand Up @@ -86,6 +93,29 @@ def __init__(self):
self.smallest_subnormal = float8_e5m2fnuz(smallest_subnormal)


class _Float8IEEEMachArLike:

def __init__(self, p):
# These are hard-coded in order to independently test against the computed values in the C++ implementation
if p == 3:
smallest_normal = float.fromhex("0x1p-15")
self.smallest_normal = float8_p3109_p3(smallest_normal)
smallest_subnormal = float.fromhex("0x1p-17")
self.smallest_subnormal = float8_p3109_p3(smallest_subnormal)

if p == 4:
smallest_normal = float.fromhex("0x1p-7")
self.smallest_normal = float8_p3109_p4(smallest_normal)
smallest_subnormal = float.fromhex("0x1p-10")
self.smallest_subnormal = float8_p3109_p4(smallest_subnormal)

if p == 5:
smallest_normal = float.fromhex("0x1p-3")
self.smallest_normal = float8_p3109_p5(smallest_normal)
smallest_subnormal = float.fromhex("0x1p-7")
self.smallest_subnormal = float8_p3109_p5(smallest_subnormal)


class finfo(np.finfo): # pylint: disable=invalid-name,missing-class-docstring
__doc__ = np.finfo.__doc__
_finfo_cache: Dict[np.dtype, np.finfo] = {}
Expand Down Expand Up @@ -360,6 +390,114 @@ def float_to_str(f):
# pylint: enable=protected-access
return obj

@staticmethod
def _float8_p3109_p_finfo(p):
  """Build the `np.finfo`-style record for the P3109 float8 type of precision p.

  Args:
    p: Precision of the format; 3, 4 and 5 are handled.

  Returns:
    An `np.finfo` instance, allocated with `object.__new__`, whose fields are
    filled in for the requested float8_p3109_p<p> type.

  Raises:
    NotImplementedError: If `p` matches none of the handled precisions.
  """

  def as_text(value):
    return "%6.2e" % float(value)

  # One row per precision: the scalar type, its cached np.dtype, and the
  # hard-coded (tiny, eps, epsneg, max) values the formulas below must match.
  known_formats = (
      (
          3,
          float8_p3109_p3,
          _float8_p3109_p3_dtype,
          ("0x1p-15", "0x1p-2", "0x1p-3", "0x1.8p15"),
      ),
      (
          4,
          float8_p3109_p4,
          _float8_p3109_p4_dtype,
          ("0x1p-7", "0x1p-3", "0x1p-4", "0x1.Cp7"),
      ),
      (
          5,
          float8_p3109_p5,
          _float8_p3109_p5_dtype,
          ("0x1p-3", "0x1p-4", "0x1p-5", "0x1.Ep3"),
      ),
  )
  for precision, dtype, numpy_dtype, expected_hex in known_formats:
    if p == precision:
      break
  else:
    raise NotImplementedError()

  # pylint: disable=protected-access
  obj = object.__new__(np.finfo)
  obj.dtype = numpy_dtype
  obj._machar = _Float8IEEEMachArLike(p)

  exponent_bias = 2 ** (7 - p)
  smallest_normal = obj._machar.smallest_normal
  machep, negep = 1 - p, -p
  eps = 2.0**machep
  epsneg = 2.0**negep
  # Largest finite value, e.g. for p == 4: 1'0000 - 0'0010 = 0'1110
  largest = (1 - 2 ** (1 - p)) * 2**exponent_bias

  # Cross-check the closed-form values against the hard-coded table row.
  want_tiny, want_eps, want_epsneg, want_max = map(float.fromhex, expected_hex)
  assert smallest_normal == want_tiny
  assert eps == want_eps
  assert epsneg == want_epsneg
  assert largest == want_max

  obj.bits = 8

  # eps is nextafter(1.0, Inf) - 1.0; machep is the exponent that yields it.
  obj.eps = dtype(eps)
  obj.machep = machep

  # epsneg is 1.0 - nextafter(1.0, -Inf); negep is the exponent that yields it.
  obj.epsneg = dtype(epsneg)
  obj.negep = negep

  # Largest representable number and its negation.
  obj.max = dtype(largest)
  obj.min = dtype(-largest)

  obj.nexp = 8 - p
  obj.nmant = p - 1
  obj.iexp = obj.nexp
  obj.maxexp = exponent_bias
  obj.minexp = 1 - exponent_bias

  # Approximate number of decimal digits to which this float is precise.
  if p < 4:
    obj.precision = 1
  else:
    obj.precision = 2

  # Approximate decimal resolution of this type, i.e. 10**-precision.
  obj.resolution = dtype(10**-obj.precision)

  # Only assign these when the np.finfo object does not already expose them.
  if not hasattr(obj, "tiny"):
    obj.tiny = dtype(smallest_normal)
  if not hasattr(obj, "smallest_normal"):
    obj.smallest_normal = obj._machar.smallest_normal
  obj.smallest_subnormal = obj._machar.smallest_subnormal

  obj._str_tiny = as_text(smallest_normal)
  obj._str_smallest_normal = as_text(smallest_normal)
  obj._str_smallest_subnormal = as_text(obj.smallest_subnormal)
  obj._str_max = as_text(largest)
  obj._str_epsneg = as_text(epsneg)
  obj._str_eps = as_text(eps)
  obj._str_resolution = as_text(obj.resolution)
  # pylint: enable=protected-access
  return obj

@classmethod
def _float8_p3109_p3_finfo(cls):
  """Return the finfo record for `float8_p3109_p3` (precision p = 3)."""
  # Dispatch through `cls` rather than the module-level `finfo` name, so the
  # call does not depend on what that name is bound to at call time.
  return cls._float8_p3109_p_finfo(3)

@classmethod
def _float8_p3109_p4_finfo(cls):
  """Return the finfo record for `float8_p3109_p4` (precision p = 4)."""
  # Dispatch through `cls` rather than the module-level `finfo` name, so the
  # call does not depend on what that name is bound to at call time.
  return cls._float8_p3109_p_finfo(4)

@classmethod
def _float8_p3109_p5_finfo(cls):
  """Return the finfo record for `float8_p3109_p5` (precision p = 5)."""
  # Dispatch through `cls` rather than the module-level `finfo` name, so the
  # call does not depend on what that name is bound to at call time.
  return cls._float8_p3109_p_finfo(5)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems strange for a static method to refer to its class by name...

Maybe we could remove these three and use cls._float8_p3109_p_finfo(p) directly in __new__ below

Copy link
Author

@awf awf Nov 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems strange for a static method to refer to its class by name...

Agreed

Maybe we could remove these three and use cls._float8_p3109_p_finfo(p) directly in __new__ below

Done. As an aside I put all the tests into the same for loop, makes the code rather tidier (and no measurable speed impact in pytest), hope that's reasonable.


def __new__(cls, dtype):
if (
isinstance(dtype, str)
Expand Down Expand Up @@ -411,4 +549,14 @@ def __new__(cls, dtype):
if _float8_e5m2fnuz_dtype not in cls._finfo_cache:
cls._finfo_cache[_float8_e5m2fnuz_dtype] = cls._float8_e5m2fnuz_finfo()
return cls._finfo_cache[_float8_e5m2fnuz_dtype]
for type_str, test_dtype, finfo in (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's confusing that the local finfo variable shadows the finfo class name.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Totally agree, apologies.

("float8_p3109_p3", _float8_p3109_p3_dtype, cls._float8_p3109_p3_finfo),
("float8_p3109_p4", _float8_p3109_p4_dtype, cls._float8_p3109_p4_finfo),
("float8_p3109_p5", _float8_p3109_p5_dtype, cls._float8_p3109_p5_finfo),
):
if isinstance(dtype, str) and dtype == type_str or dtype == test_dtype:
if test_dtype not in cls._finfo_cache:
cls._finfo_cache[test_dtype] = finfo()
return cls._finfo_cache[test_dtype]

return super().__new__(cls, dtype)
Loading