From 6b9be88412c9f9c673480d226cd13fb451e48e59 Mon Sep 17 00:00:00 2001 From: Kevin Kenyon Date: Sun, 28 Jul 2024 19:27:26 -0500 Subject: [PATCH] Add padding to lower and use symbolic for index and ternary --- shrimpgrad/engine/lower.py | 35 +++++++++++++++++++++++----------- shrimpgrad/engine/runner.py | 4 ++-- shrimpgrad/engine/scheduler.py | 4 ++-- shrimpgrad/future.py | 6 ++---- shrimpgrad/runtime/clang.py | 20 ++++++++++++++++--- shrimpgrad/runtime/ops.py | 2 +- shrimpgrad/tensor.py | 1 - shrimpgrad/view.py | 9 +++++---- test/test_tensor.py | 2 +- test/test_thunk.py | 3 +-- 10 files changed, 55 insertions(+), 31 deletions(-) diff --git a/shrimpgrad/engine/lower.py b/shrimpgrad/engine/lower.py index 254a6d3..53162fa 100644 --- a/shrimpgrad/engine/lower.py +++ b/shrimpgrad/engine/lower.py @@ -9,8 +9,9 @@ from shrimpgrad.engine.scheduler import FusedKernel from shrimpgrad.runtime.ops import BinaryOps, LoadOps, ReduceOps, TernaryOps, UnaryOps from shrimpgrad.util import prod +from shrimpgrad.view import ViewTracker -def alu2str(op:BinaryOps|UnaryOps) -> str: +def alu2str(op:BinaryOps|UnaryOps|TernaryOps) -> str: assert op in BinaryOps or op in UnaryOps, f"{op} is not a binary/unary alu op" if op is BinaryOps.ADD: return '+' if op is BinaryOps.CMPEQ: return '==' @@ -27,6 +28,7 @@ def alu2str(op:BinaryOps|UnaryOps) -> str: if op is UnaryOps.NEG: return '-' if op is UnaryOps.SIN: return 'sin' if op is UnaryOps.SQRT: return 'sqrt' + raise ValueError(f"{op} is not a valid alu op") class LowIR(Enum): GLOBAL = auto() @@ -86,7 +88,10 @@ class AddressNode(Node): idx: Tuple[LocalNode|int,...] stride: Tuple[int,...] step: int + vt: Optional[ViewTracker]=None def __repr__(self) -> str: + if self.vt is not None: + return f"{'ADDRESS':<15}{self.vt.render()[0]:<10}" addr = '' for idx, stride in zip(self.idx, self.stride): val = idx.name if isinstance(idx, LocalNode) else idx @@ -127,7 +132,7 @@ def __hash__(self): return self.hash @dataclass(frozen=True, eq=True) class ALUNode(Node): - alu: Union[BinaryOps, UnaryOps] + alu: BinaryOps|UnaryOps|TernaryOps dtype: DType def __repr__(self) -> str: operands = f"{' ':<10}".join([f"{str(operand.op)}" for operand in self.ancestors]) @@ -174,7 +179,7 @@ class LowIRGraph: def __init__(self): self.G: List[Node] = [] - def const(self, dtype: DType, val: ConstType) -> Node: + def const(self, dtype: DType, val: ConstType) -> ConstNode: self.G.append(node:=ConstNode(LowIR.CONST, (), dtype, val)) return node @@ -186,12 +191,15 @@ def local_var(self, name: str, dtype: DType, val: ConstNode|ALUNode) -> LocalNod self.G.append(node:=LocalNode(LowIR.LOCAL, (val, ), name, dtype)) return node - def address(self, idxs: List[LocalNode|int], strides: Tuple[int,...], step: int): + def address(self, idxs: List[LocalNode] | List[int], strides: Tuple[int,...], step: int, vt: Optional[ViewTracker]=None) -> AddressNode: + if vt is not None: + self.G.append(node:=AddressNode(LowIR.ADDRESS, (), tuple(idxs), (), 0, vt)) + return node self.G.append(node:=AddressNode(LowIR.ADDRESS, (), tuple(idxs), strides, step)) return node - def load(self, node: Union[GlobalNode, LocalNode], location: Optional[AddressNode|OffsetNode]) -> Node: - self.G.append(node:=LoadNode(LowIR.LOAD, (node, location))) + def load(self, inp: GlobalNode|LocalNode, location: Optional[AddressNode|OffsetNode]) -> LoadNode: + self.G.append(node:=LoadNode(LowIR.LOAD, (inp, location))) return node def accumulator(self, alu: BinaryOps, @@ -201,7 +209,7 @@ def accumulator(self, alu: BinaryOps, return node def store(self, lhs: Union[GlobalNode, LocalNode], - address: AddressNode|OffsetNode|None, + address: AddressNode|OffsetNode|None|ViewTracker, rhs: Union[LoadNode, ConstNode, LocalNode]) -> Node: self.G.append(node:=StoreNode(LowIR.STORE, (lhs, address, rhs))) return node @@ -289,7 +297,7 @@ def lower_alu(self, alu: BinaryOps|TernaryOps|UnaryOps, *loads): # TODO: Deal with dtype return self.g.alu(alu, self.dtype, *loads) - def lower_store(self, g: GlobalNode|LocalNode, addr: AddressNode|OffsetNode|None, value: Node) -> StoreNode: + def lower_store(self, g: GlobalNode|LocalNode, addr: AddressNode|OffsetNode|None|ViewTracker, value: Node) -> StoreNode: if g.__class__ is LocalNode: return self.g.store(g, None, value) return self.g.store(g, addr, value) if g.ptr else self.g.store(g, None, value) @@ -419,7 +427,6 @@ def lower_rop(self, self.lower_store(out, out_off, acc) self.g.inc(off) - def lower_begin_for(self, s: int, e: int) -> Tuple[BeginLoopNode, LocalNode]: c0 = self.g.const(dtypes.int32, s) # start c1 = self.g.const(dtypes.int32, e) # end @@ -430,7 +437,7 @@ def lower_begin_for(self, s: int, e: int) -> Tuple[BeginLoopNode, LocalNode]: return loop, idx # TODO: Loop unrolling (but not ncessary once we gen for GPU) - def lower_start_loops(self, ndim:int, shape: Tuple[int,...]): + def lower_start_loops(self, ndim:int, shape: Tuple[int,...]) -> Tuple[List[BeginLoopNode], List[LocalNode]]: loops, idxs = [], [] for dim in range(ndim): # Create two constant values for the loop index @@ -471,11 +478,17 @@ def lower_single_op_kernel(self, fused_kernel: FusedKernel): gout = self.lower_io(output, is_input=False) loops, idxs = self.lower_start_loops(output.vt.ndim, output.vt.shape) val = self.g.const(output.buff.dtype, arg) - print(output.vt.view.mask) addr0 = self.g.address(idxs, output.vt.strides, 1) self.lower_store(gout, addr0, val) self.lower_end_loops(loops) return + if op is LoadOps.PAD: + gout = self.lower_io(output, is_input=False) + loops, idxs = self.lower_start_loops(output.vt.ndim, output.vt.shape) + val = self.g.const(output.buff.dtype, arg) + addr = self.g.address(idxs, (), 0, output.vt) + self.lower_store(gout, addr, val) + self.lower_end_loops(loops) if op in LoadOps and op is not LoadOps.ASSIGN: self.lower_io(output, is_input=True) return diff --git a/shrimpgrad/engine/runner.py b/shrimpgrad/engine/runner.py index b104785..ac1bf0e 100644 --- a/shrimpgrad/engine/runner.py +++ b/shrimpgrad/engine/runner.py @@ -60,7 +60,7 @@ def name_kernels(kernels: List[FusedKernel]) -> List[str]: func_name for func_name in ( '_'.join( [op.name.lower() for op in s.computation.ops] + - ['0' if s.computation.ops[0] in [LoadOps.CONST,LoadOps.CONTIGUOUS] else '_'.join( + ['0' if s.computation.ops[0] in [LoadOps.CONST,LoadOps.CONTIGUOUS, LoadOps.PAD] else '_'.join( map(str, s.computation.ins[0][0].vt.shape))] + ['_'.join(map(str, s.computation.out[-1].vt.shape))] + [str(i)] @@ -71,7 +71,7 @@ def _gen_load_kernels(schedule: List[FusedKernel]) -> Tuple[List[BufferCopy], Li l, u = [], [] for fk in schedule: c = fk.computation - if c.ops[0] == LoadOps.CONST and c.out[0].buff.allocated: continue + # if c.ops[0] == LoadOps.CONST and c.out[0].buff.allocated: continue if len(c.ins) == 1 and c.ops[0] == LoadOps.COPY: l.append(BufferCopy(c.out[0], c.ins[0][0], c.args[0])) else: u.append(fk) diff --git a/shrimpgrad/engine/scheduler.py b/shrimpgrad/engine/scheduler.py index c229f4e..e5a677e 100644 --- a/shrimpgrad/engine/scheduler.py +++ b/shrimpgrad/engine/scheduler.py @@ -41,7 +41,7 @@ def schedule_fused(self) -> List[FusedKernel]: # Assign thunks can be realized in forward passes as their backing buffer # is realized. This is useful for optimizer updates to params. is_realized = thunk.realized is not None - if is_realized and thunk._op is not LoadOps.ASSIGN: continue + if is_realized and thunk._op is not LoadOps.ASSIGN and thunk._op is not LoadOps.PAD: continue if thunk._op is LoadOps.ASSIGN and thunk._operands[1].realized is not None : assert is_realized, 'assign target must be realized' continue @@ -53,7 +53,7 @@ def schedule_fused(self) -> List[FusedKernel]: fused_unrealized = [] # If a thunk in this fusion group is already realized, there's no # need to schedule it again, it's backing buffer already has - # the result that we would compute again causing a doubling of the buffer . + # the result that we would compute again causing a doubling of the buffer # This is vital when running # backward passes because backward functions sometimes use input thunks that # become realized in the forward pass. When backward includes those thunks diff --git a/shrimpgrad/future.py b/shrimpgrad/future.py index feaaa74..abc5268 100644 --- a/shrimpgrad/future.py +++ b/shrimpgrad/future.py @@ -114,10 +114,7 @@ def expand(self, shape: Tuple[int,...]) -> Thunk: return create_thunk(self.device, self.dtype, self.vt.expand(shape), (), base=self.base) def pad(self, pad_width: Tuple[Tuple[int, int],...], value: ConstType=0.0): - thunk = create_thunk(self.device, self.dtype, self.vt.pad(pad_width), (), base=self.base, arg=value) - # TODO: Hack to fix a test so I can commit - self.base.buff = Buffer(self.device, prod(thunk.shape), thunk.dtype) - return thunk + return create_thunk(self.device, self.dtype, self.vt.pad(pad_width), (), op=LoadOps.PAD, base=self.base, arg=value) def shrink(self, shrink_width: Tuple[Tuple[int, int],...]): return create_thunk(self.device, self.dtype, self.vt.shrink(shrink_width), (), base=self.base) @@ -129,6 +126,7 @@ def cast(self, dtype: DType) -> Thunk: def load_const(val: ConstType, shape: Tuple[int,...], dtype: DType, device: Device): assert isinstance(val, ConstType), f'load_const expects const val, got {val}' thunk = Thunk.loadop(LoadOps.CONST, (), dtype, device, arg=val) + # Allocate so we don't schedule the load thunk.buff.allocate(with_data=val) return thunk.reshape((1,)*len(shape)).expand(shape) diff --git a/shrimpgrad/runtime/clang.py b/shrimpgrad/runtime/clang.py index 6c7150a..9c62c15 100644 --- a/shrimpgrad/runtime/clang.py +++ b/shrimpgrad/runtime/clang.py @@ -1,10 +1,12 @@ from __future__ import annotations import ctypes, pathlib, tempfile, subprocess -from typing import DefaultDict, Dict, List +from typing import DefaultDict, Dict, List, Tuple, cast from shrimpgrad.device import Device, Compiler, Jitable, MallocAllocator, MemBuffer, Renderer, Runtime from shrimpgrad.dtype import dtypes -from shrimpgrad.engine.lower import ALUNode, ConstNode, GlobalNode, LocalNode, LowIR, LowIRGraph, alu2str +from shrimpgrad.engine.lower import ALUNode, AddressNode, ConstNode, GlobalNode, LocalNode, LowIR, LowIRGraph, alu2str from shrimpgrad.runtime.ops import UnaryOps, BinaryOps, TernaryOps +from shrimpgrad.symbolic import Expr, render +from shrimpgrad.view import ViewTracker c_alu = { UnaryOps.LOG2: lambda x: f'log2({x})', @@ -117,7 +119,6 @@ def param2src(self, g): param += ' ' + g[0] return param - def _name(self): return f"f_{id(self.irgs[0])}" if self.func_name is None else self.func_name @@ -169,6 +170,12 @@ def gen(self): self.src.append(f"{self.spaces}for (; {s.name} < {e}; {s.name}++) {{ ") self.indent += 2 elif instr.op is LowIR.ADDRESS: + instr = cast(AddressNode, instr) + if instr.vt is not None: + exprs = instr.vt.to_symbolic([ix.name for ix in instr.idx if isinstance(ix, LocalNode)]) + self.instr_to_src[instr] = exprs + i+=1 + continue addr = '' for idx, stride in zip(instr.idx, instr.stride): val = idx.name if isinstance(idx, LocalNode) else idx @@ -204,12 +211,17 @@ def gen(self): return self def store(self, instr): + bexpr = None idx = instr.ancestors[1] if isinstance(idx, ConstNode): idx = idx.val else: if idx is not None: idx = self.instr_to_src[idx] + if isinstance(idx, Tuple): + idx = cast(Tuple[Expr,Expr],idx) + iexpr, bexpr = render(idx[0]), render(idx[1]) + idx = iexpr lhs = instr.ancestors[0] rhs = self.instr_to_src[instr.ancestors[2]] @@ -217,6 +229,8 @@ def store(self, instr): if isinstance(lhs, GlobalNode): if lhs.ptr: r = f"{lhs.name}[{idx if idx is not None else 0}] = {rhs}" + if bexpr is not None: + r = f"if (!({bexpr})) {{ {r}; }}" else: r = f"*{lhs.name} = {rhs}" else: diff --git a/shrimpgrad/runtime/ops.py b/shrimpgrad/runtime/ops.py index 8918171..eb3547f 100644 --- a/shrimpgrad/runtime/ops.py +++ b/shrimpgrad/runtime/ops.py @@ -7,7 +7,7 @@ class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); M class TernaryOps(Enum): WHERE = auto(); MULACC = auto() class ReduceOps(Enum): SUM = auto(); MAX = auto() class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() -class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); ASSIGN = auto() +class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); PAD = auto(); ASSIGN = auto() Op = Union[UnaryOps, BinaryOps, ReduceOps, LoadOps, TernaryOps, BufferOps] OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[LoadOps], Type[TernaryOps], Type[BufferOps]] diff --git a/shrimpgrad/tensor.py b/shrimpgrad/tensor.py index 93e1351..9d5073e 100644 --- a/shrimpgrad/tensor.py +++ b/shrimpgrad/tensor.py @@ -376,7 +376,6 @@ def analyze(self): grad_data = self.grad.data().flatten()[0:5] grad_alloc = grad_buffer.allocated print(f"{op = } {is_view = } alloc={buffer.allocated} {buffer_addr = } alloc={grad_alloc} {grad_buffer_addr = } {grad_data = }") - def __repr__(self): return f"" def __str__(self): return self.__repr__() diff --git a/shrimpgrad/view.py b/shrimpgrad/view.py index aae72a9..b8f9b2a 100644 --- a/shrimpgrad/view.py +++ b/shrimpgrad/view.py @@ -57,10 +57,10 @@ def pad(self, pad_width: Tuple[Tuple[int,int], ...]) -> ViewTracker: def shrink(self, arg: Tuple[Tuple[int,int], ...]) -> ViewTracker: return ViewTracker.from_views(self.views + [self.view.shrink(arg)]) - def to_symbolic(self) -> Tuple[Expr, Expr]: + def to_symbolic(self, idx_names: Optional[List[str]]=None) -> Tuple[Expr, Expr]: """ Create a symbolic expression for the final view. """ - return self.view.to_symbolic() + return self.view.to_symbolic(idx_names) def render(self) -> Tuple[str, str]: """ Render strings for the symbolic expression for the final view. @@ -190,7 +190,7 @@ def shrink(self, arg: Tuple[Tuple[int, int],...]) -> View: if self.mask is not None else None return create_view(tuple([e-s for s,e in arg]), mask=nmsk) - def to_symbolic(self) -> Tuple[Expr, Expr]: + def to_symbolic(self, idx_names: Optional[List[str]]=None) -> Tuple[Expr, Expr]: """ Create symbolic expressions for the indeces and boundaries of the final view. @@ -198,7 +198,8 @@ def to_symbolic(self) -> Tuple[Expr, Expr]: strides = [Lit(st) for st in self.strides] mask = self.mask offset = self.offset - idxs = [Symbol(f"idx{i}", Interval(0,s)) for i, s in enumerate(self.shape)] + if idx_names is None: idxs = [Symbol(f"idx{i}", Interval(0,s)) for i, s in enumerate(self.shape)] + else: idxs = [Symbol(idx_name, Interval(0,s)) for idx_name, s in zip(idx_names, self.shape)] iexpr = Lit(offset) for idx, stride in zip(idxs, strides): iexpr += idx*stride bexpr = [] diff --git a/test/test_tensor.py b/test/test_tensor.py index 1dd761c..d207c4b 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -86,7 +86,7 @@ def test_const_load_throughput(self): class TestPadAndShrink(unittest.TestCase): def test_pad(self): t = Tensor.full((2,2), 3.0) + np.testing.assert_allclose(t.numpy(), 3.0) t = t.pad(((1, 1), (1, 1))) with Knobs(DEBUG=4): print(t.numpy()) - print(t.thunk.vt) diff --git a/test/test_thunk.py b/test/test_thunk.py index a19c5d9..4ae133f 100644 --- a/test/test_thunk.py +++ b/test/test_thunk.py @@ -25,7 +25,6 @@ def test_load_const_reshape(self): self.assertTrue(src.realized is not None) self.assertEqual(src.device, ClangDevice()) - def test_load_empty_double_reshape(self): x = Tensor((2,2), [2.0]*4).reshape(*(1,2,2,)).reshape(*(2,2)) self.assertEqual(x.thunk._op, None) @@ -41,6 +40,7 @@ def test_pad(self): t1 = t.pad(((1,1),(1,1),(1,1)), 0.0) self.assertEqual((4,4,4), t1.shape) self.assertEqual(0.0, t1.arg) + self.assertEqual(t1._op, LoadOps.PAD) class TestLoads(unittest.TestCase): def test_load_const_function_nd(self): @@ -63,7 +63,6 @@ def test_load_const_function_scalar(self): self.assertEqual(4, t.base.buff.nbytes) self.assertEqual(True, t.base.buff.allocated) self.assertEqual(ClangDevice(), t.device) - def test_load_const(self): x = Tensor.full((2,2,2), 3.0)