Add padding to lower and use symbolic for index and ternary
kvkenyon committed Jul 29, 2024
1 parent 8819536 commit 6b9be88
Showing 10 changed files with 55 additions and 31 deletions.
35 changes: 24 additions & 11 deletions shrimpgrad/engine/lower.py
@@ -9,8 +9,9 @@
from shrimpgrad.engine.scheduler import FusedKernel
from shrimpgrad.runtime.ops import BinaryOps, LoadOps, ReduceOps, TernaryOps, UnaryOps
from shrimpgrad.util import prod
from shrimpgrad.view import ViewTracker

def alu2str(op:BinaryOps|UnaryOps) -> str:
def alu2str(op:BinaryOps|UnaryOps|TernaryOps) -> str:
assert op in BinaryOps or op in UnaryOps, f"{op} is not a binary/unary alu op"
if op is BinaryOps.ADD: return '+'
if op is BinaryOps.CMPEQ: return '=='
@@ -27,6 +28,7 @@ def alu2str(op:BinaryOps|UnaryOps) -> str:
if op is UnaryOps.NEG: return '-'
if op is UnaryOps.SIN: return 'sin'
if op is UnaryOps.SQRT: return 'sqrt'
raise ValueError(f"{op} is not a valid alu op")

class LowIR(Enum):
GLOBAL = auto()
@@ -86,7 +88,10 @@ class AddressNode(Node):
idx: Tuple[LocalNode|int,...]
stride: Tuple[int,...]
step: int
vt: Optional[ViewTracker]=None
def __repr__(self) -> str:
if self.vt is not None:
return f"{'ADDRESS':<15}{self.vt.render()[0]:<10}"
addr = ''
for idx, stride in zip(self.idx, self.stride):
val = idx.name if isinstance(idx, LocalNode) else idx
@@ -127,7 +132,7 @@ def __hash__(self): return self.hash

@dataclass(frozen=True, eq=True)
class ALUNode(Node):
alu: Union[BinaryOps, UnaryOps]
alu: BinaryOps|UnaryOps|TernaryOps
dtype: DType
def __repr__(self) -> str:
operands = f"{' ':<10}".join([f"{str(operand.op)}" for operand in self.ancestors])
@@ -174,7 +179,7 @@ class LowIRGraph:
def __init__(self):
self.G: List[Node] = []

def const(self, dtype: DType, val: ConstType) -> Node:
def const(self, dtype: DType, val: ConstType) -> ConstNode:
self.G.append(node:=ConstNode(LowIR.CONST, (), dtype, val))
return node

@@ -186,12 +191,15 @@ def local_var(self, name: str, dtype: DType, val: ConstNode|ALUNode) -> LocalNode:
self.G.append(node:=LocalNode(LowIR.LOCAL, (val, ), name, dtype))
return node

def address(self, idxs: List[LocalNode|int], strides: Tuple[int,...], step: int):
def address(self, idxs: List[LocalNode] | List[int], strides: Tuple[int,...], step: int, vt: Optional[ViewTracker]=None) -> AddressNode:
if vt is not None:
self.G.append(node:=AddressNode(LowIR.ADDRESS, (), tuple(idxs), (), 0, vt))
return node
self.G.append(node:=AddressNode(LowIR.ADDRESS, (), tuple(idxs), strides, step))
return node

def load(self, node: Union[GlobalNode, LocalNode], location: Optional[AddressNode|OffsetNode]) -> Node:
self.G.append(node:=LoadNode(LowIR.LOAD, (node, location)))
def load(self, inp: GlobalNode|LocalNode, location: Optional[AddressNode|OffsetNode]) -> LoadNode:
self.G.append(node:=LoadNode(LowIR.LOAD, (inp, location)))
return node

def accumulator(self, alu: BinaryOps,
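
The new optional vt argument gives address() two call shapes, both of which show up in lower_single_op_kernel further down: a plain strided address (the CONST path) and a view-tracked address whose index and bounds the backend renders symbolically (the new PAD path). A minimal sketch of the two forms, with the padded ViewTracker left hypothetical since its construction is not part of this diff:

from shrimpgrad.engine.lower import LowIRGraph
from shrimpgrad.dtype import dtypes

g = LowIRGraph()
c0 = g.const(dtypes.int32, 0)
idx0 = g.local_var("idx0", dtypes.int32, c0)  # loop counters, roughly as lower_begin_for makes them
idx1 = g.local_var("idx1", dtypes.int32, c0)

# 1) Plain strided addressing, as in the CONST branch below.
addr_strided = g.address([idx0, idx1], strides=(4, 1), step=1)

# 2) View-tracked addressing, as in the new PAD branch: strides/step are ignored and
#    the ViewTracker rides along so the renderer can emit a symbolic index plus a
#    bounds expression. `padded_vt` is hypothetical here.
# addr_masked = g.address([idx0, idx1], (), 0, padded_vt)
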
@@ -201,7 +209,7 @@ def accumulator(self, alu: BinaryOps,
return node

def store(self, lhs: Union[GlobalNode, LocalNode],
address: AddressNode|OffsetNode|None,
address: AddressNode|OffsetNode|None|ViewTracker,
rhs: Union[LoadNode, ConstNode, LocalNode]) -> Node:
self.G.append(node:=StoreNode(LowIR.STORE, (lhs, address, rhs)))
return node
@@ -289,7 +297,7 @@ def lower_alu(self, alu: BinaryOps|TernaryOps|UnaryOps, *loads):
# TODO: Deal with dtype
return self.g.alu(alu, self.dtype, *loads)

def lower_store(self, g: GlobalNode|LocalNode, addr: AddressNode|OffsetNode|None, value: Node) -> StoreNode:
def lower_store(self, g: GlobalNode|LocalNode, addr: AddressNode|OffsetNode|None|ViewTracker, value: Node) -> StoreNode:
if g.__class__ is LocalNode: return self.g.store(g, None, value)
return self.g.store(g, addr, value) if g.ptr else self.g.store(g, None, value)

@@ -419,7 +427,6 @@ def lower_rop(self,
self.lower_store(out, out_off, acc)
self.g.inc(off)


def lower_begin_for(self, s: int, e: int) -> Tuple[BeginLoopNode, LocalNode]:
c0 = self.g.const(dtypes.int32, s) # start
c1 = self.g.const(dtypes.int32, e) # end
@@ -430,7 +437,7 @@ def lower_begin_for(self, s: int, e: int) -> Tuple[BeginLoopNode, LocalNode]:
return loop, idx

# TODO: Loop unrolling (but not necessary once we gen for GPU)
def lower_start_loops(self, ndim:int, shape: Tuple[int,...]):
def lower_start_loops(self, ndim:int, shape: Tuple[int,...]) -> Tuple[List[BeginLoopNode], List[LocalNode]]:
loops, idxs = [], []
for dim in range(ndim):
# Create two constant values for the loop index
@@ -471,11 +478,17 @@ def lower_single_op_kernel(self, fused_kernel: FusedKernel):
gout = self.lower_io(output, is_input=False)
loops, idxs = self.lower_start_loops(output.vt.ndim, output.vt.shape)
val = self.g.const(output.buff.dtype, arg)
print(output.vt.view.mask)
addr0 = self.g.address(idxs, output.vt.strides, 1)
self.lower_store(gout, addr0, val)
self.lower_end_loops(loops)
return
if op is LoadOps.PAD:
gout = self.lower_io(output, is_input=False)
loops, idxs = self.lower_start_loops(output.vt.ndim, output.vt.shape)
val = self.g.const(output.buff.dtype, arg)
addr = self.g.address(idxs, (), 0, output.vt)
self.lower_store(gout, addr, val)
self.lower_end_loops(loops)
if op in LoadOps and op is not LoadOps.ASSIGN:
self.lower_io(output, is_input=True)
return
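
The new LoadOps.PAD branch above mirrors the CONST branch but routes the store through a view-tracked address, so the Clang renderer (further down in this commit) can wrap the write in a mask guard. A rough pure-Python picture of what that loop nest does for a 2x2 source padded to 4x4 — a conceptual sketch only, with the mask assumed to be per-dimension (start, end) bounds:

pad_value = 0.0
padded_shape = (4, 4)
valid = ((1, 3), (1, 3))  # assumed per-dim (start, end) of the original data inside the padded view

out = [[None] * padded_shape[1] for _ in range(padded_shape[0])]
for i in range(padded_shape[0]):
    for j in range(padded_shape[1]):
        inside = valid[0][0] <= i < valid[0][1] and valid[1][0] <= j < valid[1][1]
        if not inside:  # mirrors the `if (!(bexpr))` guard the renderer emits
            out[i][j] = pad_value
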
4 changes: 2 additions & 2 deletions shrimpgrad/engine/runner.py
@@ -60,7 +60,7 @@ def name_kernels(kernels: List[FusedKernel]) -> List[str]:
func_name for func_name in (
'_'.join(
[op.name.lower() for op in s.computation.ops] +
['0' if s.computation.ops[0] in [LoadOps.CONST,LoadOps.CONTIGUOUS] else '_'.join(
['0' if s.computation.ops[0] in [LoadOps.CONST,LoadOps.CONTIGUOUS, LoadOps.PAD] else '_'.join(
map(str, s.computation.ins[0][0].vt.shape))] +
['_'.join(map(str, s.computation.out[-1].vt.shape))] +
[str(i)]
@@ -71,7 +71,7 @@ def _gen_load_kernels(schedule: List[FusedKernel]) -> Tuple[List[BufferCopy], Li
l, u = [], []
for fk in schedule:
c = fk.computation
if c.ops[0] == LoadOps.CONST and c.out[0].buff.allocated: continue
# if c.ops[0] == LoadOps.CONST and c.out[0].buff.allocated: continue
if len(c.ins) == 1 and c.ops[0] == LoadOps.COPY:
l.append(BufferCopy(c.out[0], c.ins[0][0], c.args[0]))
else: u.append(fk)
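
With PAD grouped alongside CONST and CONTIGUOUS, a kernel whose first op is a pad gets the '0' placeholder instead of an input shape in its generated name. An illustrative reconstruction of the naming join as it reads in this hunk (the real names shrimpgrad emits may differ in detail):

# Hypothetical single fused kernel: ops = [LoadOps.PAD], output shape (4, 4), schedule index 0.
ops_part = ["pad"]            # op.name.lower() for each op
in_part = "0"                 # '0' because the first op is CONST/CONTIGUOUS/PAD
out_part = "_".join(map(str, (4, 4)))
name = "_".join(ops_part + [in_part, out_part, "0"])
print(name)                   # -> pad_0_4_4_0
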
4 changes: 2 additions & 2 deletions shrimpgrad/engine/scheduler.py
@@ -41,7 +41,7 @@ def schedule_fused(self) -> List[FusedKernel]:
# Assign thunks can be realized in forward passes as their backing buffer
# is realized. This is useful for optimizer updates to params.
is_realized = thunk.realized is not None
if is_realized and thunk._op is not LoadOps.ASSIGN: continue
if is_realized and thunk._op is not LoadOps.ASSIGN and thunk._op is not LoadOps.PAD: continue
if thunk._op is LoadOps.ASSIGN and thunk._operands[1].realized is not None :
assert is_realized, 'assign target must be realized'
continue
@@ -53,7 +53,7 @@ def schedule_fused(self) -> List[FusedKernel]:
fused_unrealized = []
# If a thunk in this fusion group is already realized, there's no
# need to schedule it again, its backing buffer already has
# the result that we would compute again causing a doubling of the buffer .
# the result that we would compute again causing a doubling of the buffer
# This is vital when running
# backward passes because backward functions sometimes use input thunks that
# become realized in the forward pass. When backward includes those thunks
6 changes: 2 additions & 4 deletions shrimpgrad/future.py
@@ -114,10 +114,7 @@ def expand(self, shape: Tuple[int,...]) -> Thunk:
return create_thunk(self.device, self.dtype, self.vt.expand(shape), (), base=self.base)

def pad(self, pad_width: Tuple[Tuple[int, int],...], value: ConstType=0.0):
thunk = create_thunk(self.device, self.dtype, self.vt.pad(pad_width), (), base=self.base, arg=value)
# TODO: Hack to fix a test so I can commit
self.base.buff = Buffer(self.device, prod(thunk.shape), thunk.dtype)
return thunk
return create_thunk(self.device, self.dtype, self.vt.pad(pad_width), (), op=LoadOps.PAD, base=self.base, arg=value)

def shrink(self, shrink_width: Tuple[Tuple[int, int],...]):
return create_thunk(self.device, self.dtype, self.vt.shrink(shrink_width), (), base=self.base)
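
pad() now returns a proper LoadOps.PAD thunk instead of reallocating the base buffer as a workaround. A usage sketch that mirrors test_thunk.test_pad later in this commit; `t` is assumed to be an existing (2,2,2) thunk as in that test:

from shrimpgrad.runtime.ops import LoadOps

t1 = t.pad(((1, 1), (1, 1), (1, 1)), 0.0)  # pad one element on every side, fill with 0.0
assert t1.shape == (4, 4, 4)
assert t1.arg == 0.0
assert t1._op is LoadOps.PAD
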
@@ -129,6 +126,7 @@ def cast(self, dtype: DType) -> Thunk:
def load_const(val: ConstType, shape: Tuple[int,...], dtype: DType, device: Device):
assert isinstance(val, ConstType), f'load_const expects const val, got {val}'
thunk = Thunk.loadop(LoadOps.CONST, (), dtype, device, arg=val)
# Allocate so we don't schedule the load
thunk.buff.allocate(with_data=val)
return thunk.reshape((1,)*len(shape)).expand(shape)

20 changes: 17 additions & 3 deletions shrimpgrad/runtime/clang.py
@@ -1,10 +1,12 @@
from __future__ import annotations
import ctypes, pathlib, tempfile, subprocess
from typing import DefaultDict, Dict, List
from typing import DefaultDict, Dict, List, Tuple, cast
from shrimpgrad.device import Device, Compiler, Jitable, MallocAllocator, MemBuffer, Renderer, Runtime
from shrimpgrad.dtype import dtypes
from shrimpgrad.engine.lower import ALUNode, ConstNode, GlobalNode, LocalNode, LowIR, LowIRGraph, alu2str
from shrimpgrad.engine.lower import ALUNode, AddressNode, ConstNode, GlobalNode, LocalNode, LowIR, LowIRGraph, alu2str
from shrimpgrad.runtime.ops import UnaryOps, BinaryOps, TernaryOps
from shrimpgrad.symbolic import Expr, render
from shrimpgrad.view import ViewTracker

c_alu = {
UnaryOps.LOG2: lambda x: f'log2({x})',
@@ -117,7 +119,6 @@ def param2src(self, g):
param += ' ' + g[0]
return param


def _name(self):
return f"f_{id(self.irgs[0])}" if self.func_name is None else self.func_name

@@ -169,6 +170,12 @@ def gen(self):
self.src.append(f"{self.spaces}for (; {s.name} < {e}; {s.name}++) {{ ")
self.indent += 2
elif instr.op is LowIR.ADDRESS:
instr = cast(AddressNode, instr)
if instr.vt is not None:
exprs = instr.vt.to_symbolic([ix.name for ix in instr.idx if isinstance(ix, LocalNode)])
self.instr_to_src[instr] = exprs
i+=1
continue
addr = ''
for idx, stride in zip(instr.idx, instr.stride):
val = idx.name if isinstance(idx, LocalNode) else idx
@@ -204,19 +211,26 @@ def gen(self):
return self

def store(self, instr):
bexpr = None
idx = instr.ancestors[1]
if isinstance(idx, ConstNode):
idx = idx.val
else:
if idx is not None:
idx = self.instr_to_src[idx]
if isinstance(idx, Tuple):
idx = cast(Tuple[Expr,Expr],idx)
iexpr, bexpr = render(idx[0]), render(idx[1])
idx = iexpr

lhs = instr.ancestors[0]
rhs = self.instr_to_src[instr.ancestors[2]]

if isinstance(lhs, GlobalNode):
if lhs.ptr:
r = f"{lhs.name}[{idx if idx is not None else 0}] = {rhs}"
if bexpr is not None:
r = f"if (!({bexpr})) {{ {r}; }}"
else:
r = f"*{lhs.name} = {rhs}"
else:
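
When the address node carries a ViewTracker, instr_to_src holds a rendered (index, bounds) pair rather than a flat offset string, and the store is wrapped in a guard built from the bounds expression. A small sketch of that string assembly, with the rendered expressions stubbed in as plain strings (the variable names and strides are made up for illustration):

name = "out0"
iexpr_src = "(idx0*4)+idx1"                                  # stands in for render(iexpr)
bexpr_src = "((idx0>=1)&&(idx0<3))&&((idx1>=1)&&(idx1<3))"   # stands in for render(bexpr)

store = f"{name}[{iexpr_src}] = 0.0"
guarded = f"if (!({bexpr_src})) {{ {store}; }}"
print(guarded)
# if (!(((idx0>=1)&&(idx0<3))&&((idx1>=1)&&(idx1<3)))) { out0[(idx0*4)+idx1] = 0.0; }
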
2 changes: 1 addition & 1 deletion shrimpgrad/runtime/ops.py
@@ -7,7 +7,7 @@ class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); M
class TernaryOps(Enum): WHERE = auto(); MULACC = auto()
class ReduceOps(Enum): SUM = auto(); MAX = auto()
class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto()
class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); ASSIGN = auto()
class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); PAD = auto(); ASSIGN = auto()

Op = Union[UnaryOps, BinaryOps, ReduceOps, LoadOps, TernaryOps, BufferOps]
OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[LoadOps], Type[TernaryOps], Type[BufferOps]]
1 change: 0 additions & 1 deletion shrimpgrad/tensor.py
@@ -376,7 +376,6 @@ def analyze(self):
grad_data = self.grad.data().flatten()[0:5]
grad_alloc = grad_buffer.allocated
print(f"{op = } {is_view = } alloc={buffer.allocated} {buffer_addr = } alloc={grad_alloc} {grad_buffer_addr = } {grad_data = }")


def __repr__(self): return f"<Tensor {self.thunk!r} on {self.device} with grad {(self.grad.thunk if self.grad is not None else None)!r}>"
def __str__(self): return self.__repr__()
9 changes: 5 additions & 4 deletions shrimpgrad/view.py
@@ -57,10 +57,10 @@ def pad(self, pad_width: Tuple[Tuple[int,int], ...]) -> ViewTracker:
def shrink(self, arg: Tuple[Tuple[int,int], ...]) -> ViewTracker:
return ViewTracker.from_views(self.views + [self.view.shrink(arg)])

def to_symbolic(self) -> Tuple[Expr, Expr]:
def to_symbolic(self, idx_names: Optional[List[str]]=None) -> Tuple[Expr, Expr]:
""" Create a symbolic expression for the final view.
"""
return self.view.to_symbolic()
return self.view.to_symbolic(idx_names)

def render(self) -> Tuple[str, str]:
""" Render strings for the symbolic expression for the final view.
@@ -190,15 +190,16 @@ def shrink(self, arg: Tuple[Tuple[int, int],...]) -> View:
if self.mask is not None else None
return create_view(tuple([e-s for s,e in arg]), mask=nmsk)

def to_symbolic(self) -> Tuple[Expr, Expr]:
def to_symbolic(self, idx_names: Optional[List[str]]=None) -> Tuple[Expr, Expr]:
"""
Create symbolic expressions for the indices and
boundaries of the final view.
"""
strides = [Lit(st) for st in self.strides]
mask = self.mask
offset = self.offset
idxs = [Symbol(f"idx{i}", Interval(0,s)) for i, s in enumerate(self.shape)]
if idx_names is None: idxs = [Symbol(f"idx{i}", Interval(0,s)) for i, s in enumerate(self.shape)]
else: idxs = [Symbol(idx_name, Interval(0,s)) for idx_name, s in zip(idx_names, self.shape)]
iexpr = Lit(offset)
for idx, stride in zip(idxs, strides): iexpr += idx*stride
bexpr = []
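
to_symbolic can now be handed the loop-variable names the lowerer already created (idx0, idx1, ...), so the rendered C indexes the same counters the surrounding loops declare instead of inventing fresh symbols. The index expression it builds is offset + sum(idx_i * stride_i); a plain-Python check of that arithmetic, without reproducing the Symbol/Expr machinery:

shape, strides, offset = (4, 4), (4, 1), 0
idx_names = ["idx0", "idx1"]

def flat_index(indices: dict) -> int:
    # iexpr = Lit(offset); for idx, stride in zip(idxs, strides): iexpr += idx*stride
    return offset + sum(indices[n] * st for n, st in zip(idx_names, strides))

assert flat_index({"idx0": 2, "idx1": 3}) == 11
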
2 changes: 1 addition & 1 deletion test/test_tensor.py
@@ -86,7 +86,7 @@ def test_const_load_throughput(self):
class TestPadAndShrink(unittest.TestCase):
def test_pad(self):
t = Tensor.full((2,2), 3.0)
np.testing.assert_allclose(t.numpy(), 3.0)
t = t.pad(((1, 1), (1, 1)))
with Knobs(DEBUG=4):
print(t.numpy())
print(t.thunk.vt)
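
For reference, padding Tensor.full((2,2), 3.0) by one on each side should produce a 4x4 array with the 3s in the middle and the fill value (0.0 by default) around them. A hypothetical assertion the test could grow once the numpy output is trusted — it is not part of this commit:

import numpy as np

expected = np.zeros((4, 4), dtype=np.float32)
expected[1:3, 1:3] = 3.0
np.testing.assert_allclose(t.numpy(), expected)  # `t` is the padded tensor above
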
3 changes: 1 addition & 2 deletions test/test_thunk.py
@@ -25,7 +25,6 @@ def test_load_const_reshape(self):
self.assertTrue(src.realized is not None)
self.assertEqual(src.device, ClangDevice())


def test_load_empty_double_reshape(self):
x = Tensor((2,2), [2.0]*4).reshape(*(1,2,2,)).reshape(*(2,2))
self.assertEqual(x.thunk._op, None)
@@ -41,6 +40,7 @@ def test_pad(self):
t1 = t.pad(((1,1),(1,1),(1,1)), 0.0)
self.assertEqual((4,4,4), t1.shape)
self.assertEqual(0.0, t1.arg)
self.assertEqual(t1._op, LoadOps.PAD)

class TestLoads(unittest.TestCase):
def test_load_const_function_nd(self):
@@ -63,7 +63,6 @@ def test_load_const_function_scalar(self):
self.assertEqual(4, t.base.buff.nbytes)
self.assertEqual(True, t.base.buff.allocated)
self.assertEqual(ClangDevice(), t.device)


def test_load_const(self):
x = Tensor.full((2,2,2), 3.0)
