Skip to content

Commit

Permalink
Merge branch 'main' into typing_and_simplification
Browse files Browse the repository at this point in the history
  • Loading branch information
kvkenyon authored Jul 16, 2024
2 parents 149529d + dedc782 commit aca5dc3
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 8 deletions.
19 changes: 15 additions & 4 deletions shrimpgrad/view.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations
import functools
from itertools import accumulate
import itertools
import operator
from typing import List, Optional, Tuple
Expand All @@ -21,7 +20,7 @@ def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
if not shape: return ()
strides = tuple(itertools.accumulate(reversed(shape[1:]), operator.mul, initial=1))[::-1]
return normalize_strides(shape, strides)

class ViewTracker:
def __init__(self, views: List[View]):
self.views: List[View] = views
Expand Down Expand Up @@ -78,8 +77,20 @@ def create_view(shape: Tuple[int,...],
if 0 in shape: return View(shape, (0,)*len(shape))
return View(shape, normalize_strides(shape, strides) if strides is not None else strides)

def create_view(shape: Tuple[int,...],
strides: Optional[Tuple[int,...]]=None,
mask: Optional[Tuple[Tuple[int,int],...]]=None,
offset:int=0):

# standardize 0 in shape
if 0 in shape: return View(shape, (0,)*len(shape))
# standardize empty mask to None
if mask is not None and all((s==0 and e == dim_size for ((s,e), dim_size) in zip(mask, shape))): mask = None

return View(shape, normalize_strides(shape, strides) if strides is not None else strides, mask, offset)

class View:
"""A description of how a thunk's data is interpreted
"""The layout for the thunk
"""
def __init__(self, shape: Tuple[int,...],
strides: Optional[Tuple[int,...]]=None):
Expand Down Expand Up @@ -174,4 +185,4 @@ def shrink(self, arg: Tuple[Tuple[int, int],...]) -> View:
@staticmethod
def from_view(view: View): return create_view(view.shape, view.strides)

def __repr__(self): return f'<View shape={self.shape} strides={self.strides} contig={self.contiguous}>'
def __repr__(self): return f'<View shape={self.shape} strides={self.strides} contig={self.contiguous} mask={self.mask}>'
24 changes: 20 additions & 4 deletions test/test_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,20 @@ def test_view(self):
v = View(())
self.assertTrue(v.scalar)

def test_reshape_permute_reshape_expand(self):
def test_reshape(self):
vt = ViewTracker.from_shape((10,4))
self.assertEqual((10,4), vt.shape)
vt = vt.reshape((40,1))
self.assertEqual((40,1), vt.shape)

def test_reshape_like_permute(self):
vt = ViewTracker.from_shape((2,4))
self.assertEqual((2,4), vt.shape)
vt = vt.reshape((4,2))
self.assertEqual((4,2), vt.shape)
self.assertTrue(vt.contiguous)

def test_reshape_permute_reshape(self):
vt = ViewTracker.from_shape((2,2))
assert vt.shape == (2,2)
assert vt.strides == (2,1)
Expand Down Expand Up @@ -40,6 +53,7 @@ def test_pad2(self):
assert vt.shape == (5, 10, 10)
assert vt.numel == 5*10*10
assert vt.strides == (100,10,1)

assert vt.ndim == 3
assert len(vt.views) == 1

Expand All @@ -55,14 +69,12 @@ def test_permute_pad(self):
def test_shrink(self):
view = View((2,2,2))
view = view.shrink(((0,1), (0,1), (0,0)))

assert view.shape == (1,1,0)
assert view.strides == (0,0,0)

def test_shrink1(self):
vt = ViewTracker.from_shape((2,4,2))
vt = vt.shrink(((0,1),(1,3),(0,2)))

assert vt.view.shape == (1,2,2)
assert vt.view.strides == (0,2,1)

Expand All @@ -73,13 +85,15 @@ def test_undo_pad_with_shrink_and_mask(self):
vt = vt.shrink(((2,6), (0,7), (1,5)))
self.assertEqual((4,7,4), vt.shape)


def test_pad_then_shrink_a_bit(self):
vt = ViewTracker.from_shape((4,))
vt = vt.pad(((2,2),))
self.assertEqual((8,), vt.shape)
vt = vt.shrink(((1,8),))
self.assertEqual((7,), vt.shape)


def test_pad_then_shrink_into_outer_pad(self):
vt = ViewTracker.from_shape((4,))
vt = vt.pad(((2,2),))
Expand All @@ -94,16 +108,18 @@ def test_shrink_pad_back(self):
vt = vt.pad(((1,1),))
self.assertEqual((4,),vt.shape)


def test_pad_reshape_adjusts_mask(self):
vt = ViewTracker.from_shape((2,2))
vt = vt.pad(((1,1),(1,1)))
self.assertEqual((4,4), vt.shape)
vt = vt.reshape((1,4,4))
self.assertEqual((1,4,4), vt.shape)


def test_pad_reshape_adjust_mask2(self):
vt = ViewTracker.from_shape((8,2))
vt = vt.pad(((1,1),(1,1)))
self.assertEqual((10,4), vt.shape)
vt = vt.reshape((40,1))
self.assertEqual((40,1), vt.shape)
self.assertEqual((40,1), vt.shape)

0 comments on commit aca5dc3

Please sign in to comment.