Skip to content

Commit

Permalink
Add symbolic expression rendering to views
Browse files Browse the repository at this point in the history
  • Loading branch information
kvkenyon committed Jul 24, 2024
1 parent 0cd8567 commit aa2cfbb
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 19 deletions.
5 changes: 3 additions & 2 deletions shrimpgrad/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def show(self):

class Expr:
def __add__(self, other: Expr): return Bin(ArithOp.PLUS, self, other)
def __iadd__(self, other: Expr): return self + other
def __radd__(self, other): return self + other
def __mul__(self, other: Expr): return Bin(ArithOp.MUL, self, other)
def __rmul__(self, other: Expr): return self * other
Expand All @@ -46,9 +47,9 @@ def __rtruediv__(self, other): return self /other
def __mod__(self, other: Expr): return Bin(ArithOp.MOD, self, other)
def __rmod__(self, other: Expr): return self % other
def __lt__(self, other: Expr): return Bin(ArithOp.LT, self, other)
def __lte__(self, other: Expr): return Bin(ArithOp.LTE, self, other)
def __le__(self, other: Expr): return Bin(ArithOp.LTE, self, other)
def __gt__(self, other: Expr): return Bin(ArithOp.GT, self, other)
def __gte__(self, other: Expr): return Bin(ArithOp.GTE, self, other)
def __ge__(self, other: Expr): return Bin(ArithOp.GTE, self, other)
def __eq__(self, other: Expr): return Bin(ArithOp.EQ, self, other)
def __neg__(self): return Unary(ArithOp.NEG, self)
def ifelse(self, x, y): return IfElse(self, x , y)
Expand Down
62 changes: 49 additions & 13 deletions shrimpgrad/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import itertools
import operator
from typing import List, Optional, Tuple
from shrimpgrad.symbolic import Expr, Interval, Lit, Symbol, render
from shrimpgrad.util import prod

def can_merge_axes(shape: Tuple[int,...], strides: Tuple[int,...], start:int, stop:int):
Expand Down Expand Up @@ -55,6 +56,16 @@ def pad(self, pad_width: Tuple[Tuple[int,int], ...]) -> ViewTracker:

def shrink(self, arg: Tuple[Tuple[int,int], ...]) -> ViewTracker:
return ViewTracker.from_views(self.views + [self.view.shrink(arg)])

def to_symbolic(self) -> Tuple[Expr, Expr]:
""" Create a symbolic expression for the final view.
"""
return self.view.to_symbolic()

def render(self) -> Tuple[str, str]:
""" Render strings for the symbolic expression for the final view.
"""
return self.view.render()

@staticmethod
def from_views(views: List[View]) -> ViewTracker:
Expand All @@ -77,19 +88,8 @@ def create_view(shape: Tuple[int,...],

class View:
"""
The view of a thunk's underlying data buffer.
TODO:
Something that defines slices within the buffer that have actual backing data
- After pad and shrink we have virtual expansion/contraction of the dimension and we want
to keep things zero copy i.e.) on pad don't copy the buffer to a new location and fill in zeros where
they are needed allocating a bunch of memory for the padding that's unecessary.
- Instead pretend we have been padded, and at realize codegen if values fall out of the valid range use the padded value
- In this way we keep the original size but don't have to actually store all the padded intermediate tensors
- An offset used for computing loop indices, where values that fall outside of the valid range are defaulted to the pad value
Symbolic Views for variable length tensors i.e. GPT2
- A way to define shapes that have variable dimensions i.e. Can range from 0 to 100 shape = (Variable('x', 0, 100))
An n-dimensional view on a 1D data buffer with support for zero-copy
movement instructions and symbolic expression generation.
"""
def __init__(self, shape: Tuple[int,...],
strides: Optional[Tuple[int,...]]=None,
Expand Down Expand Up @@ -190,6 +190,42 @@ def shrink(self, arg: Tuple[Tuple[int, int],...]) -> View:
if self.mask is not None else None
return create_view(tuple([e-s for s,e in arg]), mask=nmsk)

def to_symbolic(self) -> Tuple[Expr, Expr]:
"""
Create symbolic expressions for the indeces and
boundaries of the final view.
"""
strides = [Lit(st) for st in self.strides]
mask = self.mask
offset = self.offset
idxs = [Symbol(f"idx{i}", Interval(0,s)) for i, s in enumerate(self.shape)]
iexpr = Lit(offset)
for idx, stride in zip(idxs, strides): iexpr += idx*stride
bexpr = []
if mask is None: mask = tuple([(0,s) for s in self.shape])
for idx, m in zip(idxs, mask):
bexpr.append(idx < Lit(m[1]))
bexpr.append(idx >= Lit(m[0]))
bexpr_ = bexpr[0]
for bnext in bexpr[1:]: bexpr_ = bexpr_.and_(bnext)
return iexpr, bexpr_

def render(self) -> Tuple[str,str]:
"""
Create two strings representing the indeces and
boundary expressions for the final view.
ex.)
vt = ViewTracker.from_shape((2,2))
vt = vt.pad(((1,1),(1,1))).render()
indexing expr: (-3 + (idx0 * 2)) + (idx1 * 1)
boundary expr: (((idx0 < 3) && (idx0 >= 1)) && (idx1 < 3)) && (idx1 >= 1)
"""
iexpr, bexpr = self.to_symbolic()
return render(iexpr), render(bexpr)

@staticmethod
def from_view(view: View): return create_view(view.shape, view.strides)

Expand Down
6 changes: 3 additions & 3 deletions test/test_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,11 @@ def test_render_pad(self):
strides = [Lit(st) for st in stride]
iexpr = idx0*strides[0]+off+idx1*strides[1]
self.assertEqual("(idx0 * 2 + -3) + (idx1 * 1)", render(iexpr))
a = idx0*(Lit(-1)) > Lit(0)
b = idx1*(Lit(-1)) > Lit(0)
a = idx0*(Lit(-1)) < Lit(0)
b = idx1*(Lit(-1)) < Lit(0)
c = idx0 < Lit(3)
d = idx1 < Lit(3)
vexpr = a.and_(b.and_(c).and_(d))
self.assertEqual("(idx0 * -1 > 0) && (((idx1 * -1 > 0) && (idx0 < 3)) && (idx1 < 3))", render(vexpr))
self.assertEqual("(idx0 * -1 < 0) && (((idx1 * -1 < 0) && (idx0 < 3)) && (idx1 < 3))", render(vexpr))


38 changes: 37 additions & 1 deletion test/test_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,4 +141,40 @@ def test_pad_reshape_adjust_mask2(self):
vt = vt.reshape((40,1))
self.assertIsNone(vt.view.mask)
self.assertEqual((40,1), vt.shape)
print(vt)
print(vt)

def test_unmasked_1D_symbolic_render(self):
vt = ViewTracker.from_shape((3,))
i, b = vt.render()
self.assertEqual("0 + (idx0 * 1)", i)
self.assertEqual("(idx0 < 3) && (idx0 >= 0)", b)

def test_unmasked_2D_symbolic_render(self):
vt = ViewTracker.from_shape((3,3))
i, b = vt.render()
self.assertEqual("(0 + (idx0 * 3)) + (idx1 * 1)", i)
self.assertEqual("(((idx0 < 3) && (idx0 >= 0)) && (idx1 < 3)) && (idx1 >= 0)", b)

def test_shrink_symbolic_render(self):
vt = ViewTracker.from_shape((3,3))
vt = vt.shrink(((1,3),(2,3)))
i, b = vt.render()
self.assertEqual("(0 + (idx0 * 1)) + (idx1 * 0)", i)
self.assertEqual("(((idx0 < 2) && (idx0 >= 0)) && (idx1 < 1)) && (idx1 >= 0)", b)

def test_pad_symbolic_render(self):
vt = ViewTracker.from_shape((2,2))
vt = vt.pad(((1,1),(1,1)))
iexpr, vexpr = vt.render()
self.assertEqual("(-3 + (idx0 * 2)) + (idx1 * 1)", iexpr)
self.assertEqual("(((idx0 < 3) && (idx0 >= 1)) && (idx1 < 3)) && (idx1 >= 1)", vexpr)

def test_pad_shrink_symbolic_render(self):
vt = ViewTracker.from_shape((2,2))
vt = vt.pad(((1,1),(1,1)))
vt = vt.shrink(((1,3),(2,3)))
i, b = vt.render()
print(vt.view)
# shape=(2,1) stride=(1,0) mask=((0,2), (0,2))
self.assertEqual("(0 + (idx0 * 1)) + (idx1 * 0)", i)
self.assertEqual("(((idx0 < 2) && (idx0 >= 0)) && (idx1 < 2)) && (idx1 >= 0)", b)

0 comments on commit aa2cfbb

Please sign in to comment.