Skip to content

Commit

Permalink
Add uint8 dtype support and mnist fetching
Browse files Browse the repository at this point in the history
  • Loading branch information
kvkenyon committed Aug 28, 2024
1 parent 3e1c1ed commit 45dbfd0
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 3 deletions.
3 changes: 2 additions & 1 deletion shrimpgrad/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@
from shrimpgrad.dtype import *
from shrimpgrad.device import *
from shrimpgrad.engine.graph import *
from shrimpgrad.knobs import *
from shrimpgrad.knobs import *
from shrimpgrad.nn import *
7 changes: 6 additions & 1 deletion shrimpgrad/dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@ def __repr__(self):

class dtypes:
int32: Final[DType] = DType(4, "int32")
uint8: Final[DType] = DType(1, "uint8")
float32: Final[DType] = DType(4, "float32")
bool_: Final[DType] = DType(1, "bool")
@staticmethod
def from_py(x: Union[float,int,bool,np.uint8]) -> DType:
  """Map a Python scalar (or a numpy uint8 scalar) to the matching DType.

  Args:
    x: the value whose exact type selects the dtype.

  Returns:
    dtypes.float32 for float, dtypes.int32 for int, dtypes.bool_ for bool
    and dtypes.uint8 for np.uint8.

  Raises:
    TypeError: if type(x) is none of the supported types.
  """
  t = type(x)
  # Exact type identity on purpose: bool is a subclass of int, so an
  # isinstance() check would classify True/False as int32.
  if t is float: return dtypes.float32
  if t is int: return dtypes.int32
  if t is bool: return dtypes.bool_
  if t is np.uint8: return dtypes.uint8
  # No unconditional fallback: unsupported types must fail loudly instead of
  # being silently reported as bool_.
  raise TypeError(f"dtype {t} is not supported.")

@staticmethod
def cast(dtype: DType, x: ConstType) -> ConstType:
Expand All @@ -36,10 +39,12 @@ def to_numpy(dtype: DType) :
if dtype == dtypes.float32: return np.float32
if dtype == dtypes.int32: return np.int32
if dtype == dtypes.bool_: return np.bool_
if dtype == dtypes.uint8: return np.uint8
raise TypeError(f"dtype {dtype} is not supported.")

def to_ctype(dtype: DType):
  """Return the ctypes scalar type that corresponds to `dtype`.

  Raises:
    TypeError: if `dtype` matches none of the known dtypes.
  """
  # Ordered (dtype, ctype) pairs; scanned front to back with ==.
  ctype_table = (
    (dtypes.float32, ctypes.c_float),
    (dtypes.int32, ctypes.c_int),
    (dtypes.bool_, ctypes.c_bool),
    (dtypes.uint8, ctypes.c_ubyte),
  )
  for known, ctype in ctype_table:
    if dtype == known:
      return ctype
  raise TypeError(f"dtype {dtype} is not supported.")
2 changes: 1 addition & 1 deletion shrimpgrad/future.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def load_from_cpu(data, dtype, shape):
thunk.buff.allocate(with_data=data)
del thunk._operands
return thunk

@staticmethod
def loadop(op: LoadOps, shape, dtype, device, arg=None, srcs=()):
return create_thunk(device, dtype, ViewTracker.from_shape(shape), srcs, op=op, arg=arg)
Expand Down
30 changes: 30 additions & 0 deletions shrimpgrad/nn/datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import os
import shutil
import urllib.request
import gzip
from shrimpgrad import Tensor, dtypes

# Download locations of the gzip-compressed MNIST files consumed by mnist():
# two image archives (training / test) and their two label archives.
url_training = "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz"
url_test = "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz"
url_training_labels = "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz"
url_test_labels = "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz"

# Directory in which downloaded files are cached between runs.
_tmp_dir = "/tmp/mnist"

def fetch(url):
  """Download `url` into the cache directory (at most once) and return its path.

  The cached file is named after the last path component of `url`. If that
  file already exists, nothing is downloaded and a notice is printed.

  Args:
    url: location of the file to download.

  Returns:
    Path of the cached file inside `_tmp_dir`.

  Raises:
    Whatever urllib.request.urlopen / file I/O raise on failure; in that case
    no file is left under the final cache name.
  """
  filename = os.path.join(_tmp_dir, os.path.basename(url))
  # exist_ok avoids the check-then-create race when two processes start together.
  os.makedirs(_tmp_dir, exist_ok=True)
  if os.path.exists(filename):
    print(f"Using cached file {filename}")
    return filename
  # Download under a temporary name and rename on success, so an interrupted
  # transfer can never be mistaken for a complete cached file later on.
  partial = f"{filename}.{os.getpid()}.part"
  try:
    with urllib.request.urlopen(url) as r, open(partial, 'wb') as f:
      shutil.copyfileobj(r, f)
    os.replace(partial, filename)
  finally:
    # Only present if the download or the rename failed.
    if os.path.exists(partial): os.remove(partial)
  return filename

def mnist():
  """Fetch the four MNIST archives and return them as uint8 tensors.

  Returns:
    A 4-tuple, in this order:
      training images, shape (60000, 1, 28, 28)
      test images,     shape (10000, 1, 28, 28)
      training labels, shape (60000,)
      test labels,     shape (10000,)
  """
  def load_data(url, offset):
    # Decompress the cached archive and drop the first `offset` bytes
    # (16 for the image files, 8 for the label files -- presumably the
    # file header). The handle is closed as soon as the payload is read.
    with gzip.open(fetch(url)) as f:
      return f.read()[offset:]
  return (
    Tensor.frombytes((60000, 1, 28, 28), load_data(url_training, 16), dtypes.uint8),
    Tensor.frombytes((10000,1, 28, 28), load_data(url_test, 16), dtypes.uint8),
    Tensor.frombytes((60000,), load_data(url_training_labels, 8), dtypes.uint8),
    Tensor.frombytes((10000,), load_data(url_test_labels, 8), dtypes.uint8)
  )
119 changes: 119 additions & 0 deletions shrimpgrad/nn/mnist_fetch.ipynb

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions shrimpgrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ def __init__(self, shape: Shape, data: Union[List[ConstType], bytes, ConstType,
self.ctx: Optional[Function] = None
self.cls: Optional[Type[Function]] = None
if isinstance(data, Thunk): self.thunk = data
elif isinstance(data, bytes):
self.thunk = Thunk.load_from_cpu(np.frombuffer(data, dtype=np.uint8), dtype, shape)
elif isinstance(data, ConstType): self.thunk = Thunk.load_const(data, shape, dtype, device)
else:
if shape == () and not isinstance(data, ConstType) and len(data) == 1: data = data[0]
Expand Down Expand Up @@ -456,6 +458,9 @@ def arange(start: int, stop:int, step:int=1, dtype:DType=dtypes.float32, **kwarg
@staticmethod
def fromlist(shape: Shape, data:List[ConstType], dtype=dtypes.float32, **kwargs):
  """Build a Tensor of `shape` from a list of constants.

  Any extra keyword arguments are passed through to the Tensor constructor.
  """
  tensor = Tensor(shape, data=data, dtype=dtype, **kwargs)
  return tensor

@staticmethod
def frombytes(shape: Shape, data: bytes, dtype=dtypes.float32, **kwargs):
  """Build a Tensor of `shape` from a raw bytes buffer.

  Args:
    shape: shape of the resulting tensor.
    data: the raw buffer; the Tensor constructor reads it as uint8 elements.
    dtype: dtype recorded for the tensor (callers loading byte data will
      normally pass dtypes.uint8).
    **kwargs: forwarded to the Tensor constructor, same as fromlist does.
  """
  # Forward kwargs instead of silently dropping them, matching fromlist.
  return Tensor(shape, data=data, dtype=dtype, **kwargs)

@staticmethod
def full(shape: Shape, fill_value: ConstType, dtype=dtypes.float32, **kwargs) -> Tensor:
if not len(shape): return Tensor((), fill_value)
Expand Down

0 comments on commit 45dbfd0

Please sign in to comment.