test dtypes
geohot committed Nov 24, 2023
1 parent 68c512a · commit 24ced37
Showing 8 changed files with 247 additions and 26 deletions.
6 changes: 4 additions & 2 deletions .github/workflows/test.yml
@@ -24,5 +24,7 @@ jobs:
run: PYTHONPATH="." python mnist.py
- name: Install torch for testing
run: pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
- name: Test Ops
run: PYTHONPATH="." python test/test_ops.py
- name: Test Ops and DTypes
run: |
PYTHONPATH="." python test/test_ops.py
PYTHONPATH="." python test/test_dtypes.py
2 changes: 1 addition & 1 deletion import_from_tinygrad.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
import pathlib

FILES = ["tensor.py", "mlops.py", "nn/optim.py", "../test/test_ops.py"]
FILES = ["tensor.py", "mlops.py", "nn/optim.py", "../test/test_ops.py", "../test/test_dtype.py"]
src = pathlib.Path("../tinygrad/tinygrad")
dest = pathlib.Path("teenygrad")

2 changes: 1 addition & 1 deletion mnist.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
import numpy as np
from teenygrad.tensor import Tensor
from teenygrad import Tensor
from tqdm import trange
import gzip, os

4 changes: 3 additions & 1 deletion sz.py
@@ -24,4 +24,6 @@
for dir_name, group in itertools.groupby(sorted([(x[0].rsplit("/", 1)[0], x[1]) for x in table]), key=lambda x:x[0]):
print(f"{dir_name:30s} : {sum([x[1] for x in group]):6d}")

print(f"\ntotal line count: {sum([x[1] for x in table])}")
total_line_count = sum([x[1] for x in table])
print(f"\ntotal line count: {total_line_count}")
assert total_line_count < 1000, "TEENYGRAD IS FUCKING TEENY IF YOU GO OVER 1000 LINES IN TEENYGRAD MIGHT AS WELL USE TINYGRAD U FAT FUCK"
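
Aside (not part of the diff): the per-directory totals above come from the usual sort-then-groupby pattern over (path, line count) pairs, and the new assert turns sz.py into a hard size gate. A minimal standalone sketch with hypothetical numbers in place of sz.py's real table:

import itertools

table = [("teenygrad/tensor.py", 500), ("teenygrad/helpers.py", 60), ("test/test_dtype.py", 194)]  # hypothetical rows
rows = sorted((path.rsplit("/", 1)[0], lines) for path, lines in table)
for dir_name, group in itertools.groupby(rows, key=lambda x: x[0]):
  print(f"{dir_name:30s} : {sum(lines for _, lines in group):6d}")
assert sum(lines for _, lines in table) < 1000  # same gate as above, on the fake data
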
1 change: 1 addition & 0 deletions teenygrad/__init__.py
@@ -0,0 +1 @@
from teenygrad.tensor import Tensor # noqa: F401
30 changes: 24 additions & 6 deletions teenygrad/helpers.py
@@ -1,9 +1,10 @@
from typing import Union, Tuple, Iterator, NamedTuple, Optional, Final, Any
import os, functools
from typing import Union, Tuple, Iterator, Optional, Final, Any
import os, functools, platform
import numpy as np
from math import prod # noqa: F401 # pylint:disable=unused-import
from dataclasses import dataclass

OSX = platform.system() == "Darwin"
def dedup(x): return list(dict.fromkeys(x)) # retains list order
def argfix(*x): return tuple(x[0]) if x and x[0].__class__ in (tuple, list) else x
def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
@@ -28,16 +29,33 @@ class DType:
def __repr__(self): return f"dtypes.{self.name}"

class dtypes:
@staticmethod # static methods on top, or bool in the type info will refer to dtypes.bool
def is_int(x: DType)-> bool: return x in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
@staticmethod
def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name]
def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes.float64)
@staticmethod
def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
@staticmethod
def is_float(x: DType) -> bool: return x in (dtypes.float32, dtypes.float64)
def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name]
bool: Final[DType] = DType(0, 1, "bool", np.bool_)
float16: Final[DType] = DType(9, 2, "half", np.float16)
half = float16
float32: Final[DType] = DType(10, 4, "float", np.float32)
float = float32
float64: Final[DType] = DType(11, 8, "double", np.float64)
double = float64
int8: Final[DType] = DType(1, 1, "char", np.int8)
int16: Final[DType] = DType(3, 2, "short", np.int16)
int32: Final[DType] = DType(5, 4, "int", np.int32)
int64: Final[DType] = DType(7, 8, "long", np.int64)
uint8: Final[DType] = DType(2, 1, "unsigned char", np.uint8)
bool: Final[DType] = DType(0, 1, "bool", np.bool_)
uint16: Final[DType] = DType(4, 2, "unsigned short", np.uint16)
uint32: Final[DType] = DType(6, 4, "unsigned int", np.uint32)
uint64: Final[DType] = DType(8, 8, "unsigned long", np.uint64)

# NOTE: bfloat16 isn't supported in numpy
bfloat16: Final[DType] = DType(9, 2, "__bf16", None)

DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and not v.__class__ == staticmethod}

ImageDType, IMAGE = None, 0 # junk to remove
PtrDType, ImageDType, IMAGE = None, None, 0 # junk to remove
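
Aside (not part of the diff): a quick sanity check of the reworked dtype helpers, assuming the repo root is on PYTHONPATH as in the CI step:

import numpy as np
from teenygrad.helpers import dtypes, DTYPES_DICT

assert dtypes.from_np(np.float32) == dtypes.float32                         # numpy dtype -> DType lookup
assert dtypes.is_float(dtypes.half) and not dtypes.is_float(dtypes.int32)   # half now counts as a float
assert dtypes.is_unsigned(dtypes.uint16) and not dtypes.is_unsigned(dtypes.int16)
print(sorted(DTYPES_DICT))  # every named DType attribute, aliases (half/float/double) included
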
34 changes: 19 additions & 15 deletions teenygrad/lazy.py
@@ -1,5 +1,5 @@
from __future__ import annotations
from teenygrad.helpers import DType, dtypes
from teenygrad.helpers import DType, dtypes, DEBUG
from teenygrad.ops import UnaryOps, BinaryOps, ReduceOps, TernaryOps, LoadOps
import numpy as np

@@ -18,6 +18,7 @@ def dtype(self): return dtypes.from_np(self._np.dtype)
def realized(self): return RawCPUBuffer(self._np)
@property
def shape(self): return self._np.shape
def __repr__(self): return f"<LB {self.shape} {self.dtype}>"

def schedule(self, seen=None): return []
def is_unrealized_const(self): return False
@@ -35,27 +36,30 @@ def loadop(op, shape, dtype, device, arg=None, src=None) -> LazyBuffer:
def contiguous(x): return x
def const(self, x) -> LazyBuffer: return LazyBuffer(np.full_like(self._np, x))

def cast(self, dtype:DType, bitcast:bool=False): return LazyBuffer(self._np.astype(dtype.np))
def cast(self, dtype:DType, bitcast:bool=False): return LazyBuffer(self._np.view(dtype.np) if bitcast else self._np.astype(dtype.np))

def e(self, op, *srcs:LazyBuffer):
if op == UnaryOps.NEG: return LazyBuffer(-self._np)
elif op == UnaryOps.EXP2: return LazyBuffer(np.exp2(self._np))
elif op == UnaryOps.LOG2: return LazyBuffer(np.log2(self._np))
elif op == UnaryOps.SIN: return LazyBuffer(np.sin(self._np))
elif op == UnaryOps.SQRT: return LazyBuffer(np.sqrt(self._np))
elif op == BinaryOps.ADD: return LazyBuffer(self._np + srcs[0]._np)
elif op == BinaryOps.SUB: return LazyBuffer(self._np - srcs[0]._np)
elif op == BinaryOps.MUL: return LazyBuffer(self._np * srcs[0]._np)
elif op == BinaryOps.DIV: return LazyBuffer((self._np / srcs[0]._np).astype(max(self.dtype, srcs[0].dtype).np))
elif op == BinaryOps.MAX: return LazyBuffer(np.maximum(self._np, srcs[0]._np))
elif op == BinaryOps.CMPLT: return LazyBuffer(self._np < srcs[0]._np)
elif op == TernaryOps.WHERE: return LazyBuffer(np.where(self._np, srcs[0]._np, srcs[1]._np))
if DEBUG >= 1: print(op, self, srcs)
if op == UnaryOps.NEG: ret = -self._np
elif op == UnaryOps.EXP2: ret = np.exp2(self._np)
elif op == UnaryOps.LOG2: ret = np.log2(self._np)
elif op == UnaryOps.SIN: ret = np.sin(self._np)
elif op == UnaryOps.SQRT: ret = np.sqrt(self._np)
elif op == BinaryOps.ADD: ret = self._np + srcs[0]._np
elif op == BinaryOps.SUB: ret = self._np - srcs[0]._np
elif op == BinaryOps.MUL: ret = self._np * srcs[0]._np
elif op == BinaryOps.DIV: ret = self._np / srcs[0]._np
elif op == BinaryOps.MAX: ret = np.maximum(self._np, srcs[0]._np)
elif op == BinaryOps.CMPLT: ret = self._np < srcs[0]._np
elif op == TernaryOps.WHERE: ret = np.where(self._np, srcs[0]._np, srcs[1]._np)
else: raise NotImplementedError(op)
return LazyBuffer(ret.astype(self.dtype.np if len(srcs) == 0 else max(self.dtype, *[x.dtype for x in srcs]).np, copy=False))

def r(self, op, new_shape):
if DEBUG >= 1: print(op, self, new_shape)
assert len(self.shape) == len(new_shape), "reduce shapes must have same dimensions"
axis = tuple(i for i,(a,b) in enumerate(zip(self.shape, new_shape)) if a != b)
if op == ReduceOps.SUM: return LazyBuffer(self._np.sum(axis, keepdims=True))
if op == ReduceOps.SUM: return LazyBuffer(self._np.sum(axis, dtype=self._np.dtype, keepdims=True))
elif op == ReduceOps.MAX: return LazyBuffer(self._np.max(axis, keepdims=True))
else: raise NotImplementedError(op)
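
Aside (not part of the diff): the behavior change in cast/e, sketched with plain numpy (the same semantics the CPU LazyBuffer wraps). astype converts values while view reinterprets bits, and elementwise results are now cast back to the max of the input DTypes instead of whatever numpy promotes to:

import numpy as np

a = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
print(a.astype(np.int32))  # [1 2 3 4]                    -> cast (value conversion)
print(a.view(np.int32))    # [1065353216 1073741824 ...]  -> bitcast (same bytes, new type)

x = np.array([4, 6], dtype=np.int8)
y = np.array([2, 3], dtype=np.int8)
print((x / y).dtype)                        # float64: numpy's own promotion for division
print((x / y).astype(np.int8, copy=False))  # [2 2]: what e() returns after the explicit cast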

194 changes: 194 additions & 0 deletions test/test_dtype.py
@@ -0,0 +1,194 @@
import unittest
import numpy as np
from teenygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType, OSX
from teenygrad.ops import Device
from teenygrad.tensor import Tensor, dtypes
from typing import Any, List

def is_dtype_supported(dtype: DType):
# for GPU, cl_khr_fp16 isn't supported (except now we don't need it!)
# for LLVM, it segfaults because it can't link to the casting function
if dtype == dtypes.half: return not (CI and Device.DEFAULT in ["GPU", "LLVM"]) and Device.DEFAULT != "WEBGPU" and getenv("CUDACPU") != 1
if dtype == dtypes.bfloat16: return False # numpy doesn't support bf16, tested separately in TestBFloat16DType
if dtype == dtypes.float64: return Device.DEFAULT not in ["WEBGPU", "METAL"] and not (OSX and Device.DEFAULT == "GPU")
if dtype in [dtypes.int8, dtypes.uint8]: return Device.DEFAULT not in ["WEBGPU"]
if dtype in [dtypes.int16, dtypes.uint16]: return Device.DEFAULT not in ["WEBGPU", "TORCH"]
if dtype == dtypes.uint32: return Device.DEFAULT not in ["TORCH"]
if dtype in [dtypes.int64, dtypes.uint64]: return Device.DEFAULT not in ["WEBGPU", "TORCH"]
if dtype == dtypes.bool:
# host-shareability is a requirement for storage buffers, but 'bool' type is not host-shareable
if Device.DEFAULT == "WEBGPU": return False
return True

def get_available_cast_dtypes(dtype: DType) -> List[DType]: return [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) and not k.startswith("_")] # don't cast internal dtypes

def _test_to_np(a:Tensor, np_dtype, target):
if DEBUG >= 2: print(a)
na = a.numpy()
if DEBUG >= 2: print(na, na.dtype, a.lazydata.realized)
try:
assert na.dtype == np_dtype
np.testing.assert_allclose(na, target)
except AssertionError as e:
raise AssertionError(f"\ntensor {a.numpy()} does not match target {target} with np_dtype {np_dtype}") from e

def _assert_eq(tensor:Tensor, target_dtype:DType, target):
if DEBUG >= 2: print(tensor.numpy())
try:
assert tensor.dtype == target_dtype
np.testing.assert_allclose(tensor.numpy(), target)
except AssertionError as e:
raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e

def _test_op(fxn, target_dtype:DType, target): _assert_eq(fxn(), target_dtype, target)
def _test_cast(a:Tensor, target_dtype:DType): _test_op(lambda: a.cast(target_dtype), target_dtype, a.numpy().astype(target_dtype.np).tolist())
def _test_bitcast(a:Tensor, target_dtype:DType, target): _test_op(lambda: a.bitcast(target_dtype), target_dtype, target)

class TestDType(unittest.TestCase):
DTYPE: Any = None
DATA: Any = None
@classmethod
def setUpClass(cls):
if not is_dtype_supported(cls.DTYPE): raise unittest.SkipTest("dtype not supported")
cls.DATA = np.random.randint(0, 100, size=10, dtype=cls.DTYPE.np).tolist() if dtypes.is_int(cls.DTYPE) else np.random.choice([True, False], size=10).tolist() if cls.DTYPE == dtypes.bool else np.random.uniform(0, 1, size=10).tolist()
def setUp(self):
if self.DTYPE is None: raise unittest.SkipTest("base class")

def test_to_np(self): _test_to_np(Tensor(self.DATA, dtype=self.DTYPE), self.DTYPE.np, np.array(self.DATA, dtype=self.DTYPE.np))

def test_casts_to(self): list(map(
lambda dtype: _test_cast(Tensor(self.DATA, dtype=dtype), self.DTYPE),
get_available_cast_dtypes(self.DTYPE)
))
def test_casts_from(self): list(map(
lambda dtype: _test_cast(Tensor(self.DATA, dtype=self.DTYPE), dtype),
get_available_cast_dtypes(self.DTYPE)
))

def test_same_size_ops(self):
def get_target_dtype(dtype):
if any([dtypes.is_float(dtype), dtypes.is_float(self.DTYPE)]): return max([dtype, self.DTYPE], key=lambda x: x.priority)
return dtype if dtypes.is_unsigned(dtype) else self.DTYPE
list(map(
lambda dtype: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype, target_dtype=get_target_dtype(dtype)) if dtype.itemsize == self.DTYPE.itemsize else None,
get_available_cast_dtypes(self.DTYPE)
))
def test_upcast_ops(self): list(map(
lambda dtype: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype) if dtype.itemsize > self.DTYPE.itemsize else None,
get_available_cast_dtypes(self.DTYPE)
))
def test_upcast_to_ops(self):
list(map(
lambda dtype: _test_ops(a_dtype=dtype, b_dtype=self.DTYPE) if dtype.itemsize < self.DTYPE.itemsize else None,
get_available_cast_dtypes(self.DTYPE)
))

def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):
if not is_dtype_supported(a_dtype) or not is_dtype_supported(b_dtype): return
if a_dtype == dtypes.bool or b_dtype == dtypes.bool: return
target_dtype = target_dtype or (max([a_dtype, b_dtype], key=lambda x: x.priority) if a_dtype.priority != b_dtype.priority else max([a_dtype, b_dtype], key=lambda x: x.itemsize))
_assert_eq(Tensor([1,2,3,4], dtype=a_dtype)+Tensor([1,2,3,4], dtype=b_dtype), target_dtype, [2,4,6,8])
_assert_eq(Tensor([1,2,3,4], dtype=a_dtype)*Tensor([1,2,3,4], dtype=b_dtype), target_dtype, [1,4,9,16])
_assert_eq(Tensor([[1,2],[3,4]], dtype=a_dtype)@Tensor.eye(2, dtype=b_dtype), target_dtype, [[1,2],[3,4]])
_assert_eq(Tensor([1,1,1,1], dtype=a_dtype)+Tensor.ones((4,4), dtype=b_dtype), target_dtype, 2*Tensor.ones(4,4).numpy())

class TestBFloat16DType(unittest.TestCase):
def setUp(self):
if not is_dtype_supported(dtypes.bfloat16): raise unittest.SkipTest("bfloat16 not supported")
def test_bf16_to_float(self):
with self.assertRaises(AssertionError):
_test_cast(Tensor([100000], dtype=dtypes.bfloat16), dtypes.float32, [100000])

def test_float_to_bf16(self):
with self.assertRaises(AssertionError):
_test_cast(Tensor([100000], dtype=dtypes.float32), dtypes.bfloat16, [100000])

# torch.tensor([10000, -1, -1000, -10000, 20]).type(torch.bfloat16)

def test_bf16(self):
t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.bfloat16)
t.realize()
back = t.cast(dtypes.float32)
assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20)

def test_bf16_disk_write_read(self):
from extra.utils import temp
t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.float32)
t.to(f"disk:{temp('f32')}").realize()

# hack to "cast" f32 -> bf16
dat = open(temp('f32'), "rb").read()
adat = b''.join([dat[i+2:i+4] for i in range(0, len(dat), 4)])
with open(temp('bf16'), "wb") as f: f.write(adat)

t = Tensor.empty(5, dtype=dtypes.bfloat16, device=f"disk:{temp('bf16')}").llvm().realize()
back = t.cast(dtypes.float32)
assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20)

class TestHalfDtype(TestDType): DTYPE = dtypes.half

class TestFloatDType(TestDType): DTYPE = dtypes.float

class TestDoubleDtype(TestDType): DTYPE = dtypes.double

class TestInt8Dtype(TestDType):
DTYPE = dtypes.int8
@unittest.skipIf(getenv("CUDA",0)==1 or getenv("PTX", 0)==1, "cuda saturation works differently")
def test_int8_to_uint8_negative(self): _test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint8), dtypes.uint8, [255, 254, 253, 252])

class TestUint8Dtype(TestDType):
DTYPE = dtypes.uint8
@unittest.skipIf(getenv("CUDA",0)==1 or getenv("PTX", 0)==1, "cuda saturation works differently")
def test_uint8_to_int8_overflow(self): _test_op(lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, [-1, -2, -3, -4])

@unittest.skipIf(Device.DEFAULT not in {"CPU", "TORCH"}, "only bitcast in CPU and TORCH")
class TestBitCast(unittest.TestCase):
def test_float32_bitcast_to_int32(self): _test_bitcast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.int32, [1065353216, 1073741824, 1077936128, 1082130432])
@unittest.skipIf(Device.DEFAULT == "TORCH", "no uint32 in torch")
def test_float32_bitcast_to_uint32(self): _test_bitcast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.uint32, [1065353216, 1073741824, 1077936128, 1082130432])
def test_int32_bitcast_to_float32(self): _test_bitcast(Tensor([1065353216, 1073741824, 1077936128, 1082130432], dtype=dtypes.int32), dtypes.float32, [1.0, 2.0, 3.0, 4.0])

# NOTE: these are the same as normal casts
def test_int8_bitcast_to_uint8(self): _test_bitcast(Tensor([-1, -2, -3, -4], dtype=dtypes.int8), dtypes.uint8, [255, 254, 253, 252])
def test_uint8_bitcast_to_int8(self): _test_bitcast(Tensor([255, 254, 253, 252], dtype=dtypes.uint8), dtypes.int8, [-1, -2, -3, -4])
@unittest.skipIf(Device.DEFAULT == "TORCH", "no uint64 in torch")
def test_int64_bitcast_to_uint64(self): _test_bitcast(Tensor([-1, -2, -3, -4], dtype=dtypes.int64), dtypes.uint64, [18446744073709551615, 18446744073709551614, 18446744073709551613, 18446744073709551612])
@unittest.skipIf(Device.DEFAULT == "TORCH", "no uint64 in torch")
def test_uint64_bitcast_to_int64(self): _test_bitcast(Tensor([18446744073709551615, 18446744073709551614, 18446744073709551613, 18446744073709551612], dtype=dtypes.uint64), dtypes.int64, [-1, -2, -3, -4])

def test_shape_change_bitcast(self):
with self.assertRaises(AssertionError):
_test_bitcast(Tensor([100000], dtype=dtypes.float32), dtypes.uint8, [100000])

class TestInt16Dtype(TestDType): DTYPE = dtypes.int16
class TestUint16Dtype(TestDType): DTYPE = dtypes.uint16

class TestInt32Dtype(TestDType): DTYPE = dtypes.int32
class TestUint32Dtype(TestDType): DTYPE = dtypes.uint32

class TestInt64Dtype(TestDType): DTYPE = dtypes.int64
class TestUint64Dtype(TestDType): DTYPE = dtypes.uint64

class TestBoolDtype(TestDType): DTYPE = dtypes.bool

class TestEqStrDType(unittest.TestCase):
def test_image_ne(self):
if ImageDType is None: raise unittest.SkipTest("no ImageDType support")
assert dtypes.float == dtypes.float32, "float doesn't match?"
assert dtypes.imagef((1,2,4)) != dtypes.imageh((1,2,4)), "different image dtype doesn't match"
assert dtypes.imageh((1,2,4)) != dtypes.imageh((1,4,2)), "different shape doesn't match"
assert dtypes.imageh((1,2,4)) == dtypes.imageh((1,2,4)), "same shape matches"
assert isinstance(dtypes.imageh((1,2,4)), ImageDType)
def test_ptr_ne(self):
if PtrDType is None: raise unittest.SkipTest("no PtrDType support")
# TODO: is this the wrong behavior?
assert PtrDType(dtypes.float32) == dtypes.float32
#assert PtrDType(dtypes.float32) == PtrDType(dtypes.float32)
#assert PtrDType(dtypes.float32) != dtypes.float32
def test_strs(self):
if PtrDType is None: raise unittest.SkipTest("no PtrDType support")
self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))")
self.assertEqual(str(PtrDType(dtypes.float32)), "ptr.dtypes.float")

if __name__ == '__main__':
unittest.main()
