Skip to content

Commit

Permalink
Add mask reshaping on reshapes
Browse files Browse the repository at this point in the history
  • Loading branch information
kvkenyon committed Jul 23, 2024
1 parent 27a3737 commit 5fa3800
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 10 deletions.
57 changes: 53 additions & 4 deletions shrimpgrad/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,18 @@ def numel(self) -> int: return prod(self.shape)
def ndim(self): return len(self.shape)

def reshape(self, new_shape: Tuple[int,...]) -> View:
    """Return a view with shape ``new_shape`` without copying data.

    If this view carries a mask, the mask is first translated to the new
    shape with ``_reshape_mask`` (which may yield ``None``).

    Args:
        new_shape: Target shape. For a view whose shape is non-empty it must
            have the same number of elements as this view.

    Returns:
        A view built by ``create_view`` for ``new_shape``.

    Raises:
        AssertionError: If the element counts of the old and new shape differ.
    """
    new_mask = None
    if self.mask is not None:
        new_mask = _reshape_mask(self.mask, self.shape, new_shape)

    if len(self.shape):
        assert prod(new_shape) == self.numel, f'shape \'{new_shape}\' is invalid for input of size {self.numel} of shape {self.shape}'
        # Fast path (new strides are easy to compute).
        # NOTE(review): new_mask is not forwarded here — presumably a masked
        # view is never reported as contiguous; confirm against `contiguous`.
        if self.contiguous:
            return create_view(new_shape)
        # Slow path (reconstruct the new strides without copying).
        new_strides = self._attempt_no_copy_reshape(new_shape)
        if new_strides is None:
            # NOTE(review): no compatible strides were found, so the view is
            # built without explicit strides — confirm callers handle this.
            return create_view(new_shape, mask=new_mask)
        return create_view(new_shape, tuple(new_strides), new_mask)

    # Empty old shape: nothing to validate, just build the new view.
    return create_view(new_shape, mask=new_mask)

def _attempt_no_copy_reshape(self, new_shape):
# Remove ones from the old shape
Expand Down Expand Up @@ -190,4 +193,50 @@ def shrink(self, arg: Tuple[Tuple[int, int],...]) -> View:
@staticmethod
def from_view(view: View): return create_view(view.shape, view.strides)

def __repr__(self):
    """Return a debug string with shape, strides, contiguity, mask and offset."""
    return f'<View shape={self.shape} strides={self.strides} contig={self.contiguous} mask={self.mask} offset={self.offset}>'

# TODO: Figure this out more clearly so we can rewrite this (this is the tinygrad implementation)
def _reshape_mask(_mask, old_shape, new_shape):
    """Translate a per-dimension mask from ``old_shape`` to ``new_shape``.

    A mask is a tuple of ``(start, end)`` pairs, one per dimension. The
    dimensions of both shapes are walked from the last to the first; old
    dimensions are merged, or a merged dimension is split, until every new
    dimension has received a ``(start, end)`` pair.

    Args:
        _mask: Tuple of ``(start, end)`` pairs for ``old_shape``, or ``None``.
        old_shape: Shape the mask currently refers to.
        new_shape: Shape the mask should be expressed in.

    Returns:
        * ``tuple((0, s) for s in new_shape)`` when ``_mask`` is ``None``.
        * ``((0, 0),) * len(new_shape)`` when some pair in ``_mask`` is empty
          (``end - start < 1``), or when a leftover leading pair is not
          ``(0, 1)``.
        * ``None`` when a bound is not an ``int`` or when the mask cannot be
          expressed in ``new_shape``.
        * Otherwise a tuple with one ``(start, end)`` pair per new dimension.
    """
    if _mask is None:
        return tuple((0, s) for s in new_shape)
    if any(not isinstance(m[0], int) or not isinstance(m[1], int) for m in _mask):
        return None
    if any(m[1] - m[0] < 1 for m in _mask):
        return ((0, 0),) * len(new_shape)  # zero mask

    new_mask = []
    # _mask is all int here; iterate every sequence from the last dim backwards.
    r_masks, r_shape, r_new_shape = reversed(_mask), reversed(old_shape), reversed(new_shape)
    curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0, 1))

    while len(new_mask) < len(new_shape):
        (l, r), next_stride = mask, new_dim * curr_stride

        if old_dim >= next_stride:  # need to split mask.
            if old_dim == next_stride:
                # Sizes line up: copy the mask and fetch the next batch for merging.
                new_mask.append((l // curr_stride, (r - 1) // curr_stride + 1))
                curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0, 1))
            else:
                # The mask can only be split if the reshape doesn't cut across it.
                if (((l % next_stride != 0 or r % next_stride != 0) and l // next_stride != (r - 1) // next_stride)
                        or old_dim % next_stride != 0):
                    return None
                new_mask.append((l % next_stride // curr_stride, (r - 1) % next_stride // curr_stride + 1))
                # Keep the current old dim and move on to the next new dimension.
                curr_stride, new_dim = next_stride, next(r_new_shape, 1)
        else:
            next_mask = next(r_masks, (0, 1))
            # Combine only if the mask can unfold continuously.
            if mask != (0, old_dim) and next_mask[1] - next_mask[0] != 1:
                return None
            mask, old_dim = (next_mask[0] * old_dim + l, (next_mask[1] - 1) * old_dim + r), old_dim * next(r_shape, 1)

    # If the old shape has leading 1s, their masks must all be (0, 1).
    for leading in r_masks:
        if leading != (0, 1):
            return ((0, 0),) * len(new_shape)  # invalid mask

    return tuple(reversed(new_mask))
15 changes: 9 additions & 6 deletions test/test_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,15 @@ def test_pad_reshape_adjusts_mask(self):
vt = vt.pad(((1,1),(1,1)))
self.assertEqual((4,4), vt.shape)
vt = vt.reshape((1,4,4))
self.assertEqual(((0, 1), (1, 3), (1, 3)), vt.view.mask)
self.assertEqual((1,4,4), vt.shape)


def test_pad_reshape_adjust_mask2(self):
    # Pad an (8,2) tracker by one on every side, then flatten it to (40,1);
    # the resulting view is expected to carry no mask.
    vt = ViewTracker.from_shape((8,2))
    vt = vt.pad(((1,1),(1,1)))
    self.assertEqual((10,4), vt.shape)
    vt = vt.reshape((40,1))
    self.assertIsNone(vt.view.mask)
    self.assertEqual((40,1), vt.shape)

0 comments on commit 5fa3800

Please sign in to comment.