diff --git a/compiler/accelerators/snax.py b/compiler/accelerators/snax.py
index 7f31ad20..38f1982d 100644
--- a/compiler/accelerators/snax.py
+++ b/compiler/accelerators/snax.py
@@ -169,6 +169,11 @@ def _generate_streamer_setup_vals(
             cst = arith.Constant.from_int_and_width(stride.data, i32)
             result.append(([cst], cst.result))

+        # address remap option
+        if StreamerOpts.HasAddressRemap in streamer.opts:
+            c0 = arith.Constant.from_int_and_width(0, i32)
+            result.append(([c0], c0.result))
+
         # channel mask option
         if StreamerOpts.HasChannelMask in streamer.opts:
             if is_zero_pattern:
@@ -189,6 +194,11 @@ def _generate_streamer_setup_vals(
                 c1 = arith.Constant.from_int_and_width(1, i32)
                 result.append(([c1], c1.result))

+        # broadcast option values come after all per-streamer values
+        for streamer in self.streamer_config.data.streamers:
+            if StreamerOpts.HasBroadcast in streamer.opts:
+                c1 = arith.Constant.from_int_and_width(1, i32)
+                result.append(([c1], c1.result))
+
         return result

     def get_streamer_setup_fields(self) -> Sequence[str]:
@@ -206,6 +216,8 @@ def get_streamer_setup_fields(self) -> Sequence[str]:
             # temporal strides
             result.extend([f"{name}_tstride_{i}" for i in range(streamer.temporal_dim)])
             # options
+            if StreamerOpts.HasAddressRemap in streamer.opts:
+                result.append(f"{name}_address_remap")
             if StreamerOpts.HasChannelMask in streamer.opts:
                 result.append(f"{name}_channel_mask")

@@ -216,6 +228,12 @@ def get_streamer_setup_fields(self) -> Sequence[str]:
             if StreamerOpts.HasTranspose in streamer.opts:
                 result.append(f"{name}_transpose")

+        # broadcast fields come after all per-streamer fields
+        for streamer, name in zip(
+            self.streamer_config.data.streamers, self.streamer_names
+        ):
+            if StreamerOpts.HasBroadcast in streamer.opts:
+                result.append(f"{name}_broadcast")
+
         return result

     def get_streamer_launch_fields(self) -> Sequence[str]:
diff --git a/compiler/accelerators/snax_gemmx.py b/compiler/accelerators/snax_gemmx.py
index 9a3d997e..be1568c0 100644
--- a/compiler/accelerators/snax_gemmx.py
+++ b/compiler/accelerators/snax_gemmx.py
@@ -26,31 +26,33 @@
     [
         Streamer(  # A
             StreamerType.Reader,
-            temporal_dims=("n", "n", "n", "n", "n", "n"),
+            temporal_dims=("n", "n", "n", "n", "n", "n", "n"),
             spatial_dims=("n",),
-            opts=(StreamerOpts.HasTranspose,),
+            opts=(StreamerOpts.HasTranspose, StreamerOpts.HasAddressRemap),
         ),
         Streamer(  # B
             StreamerType.Reader,
-            temporal_dims=("n", "n", "n"),
+            temporal_dims=("n", "n", "n", "n", "n", "n", "n"),
             spatial_dims=("n",),
-            opts=(StreamerOpts.HasTranspose,),
+            opts=(StreamerOpts.HasTranspose, StreamerOpts.HasAddressRemap),
         ),
         Streamer(  # D8
             StreamerType.Writer,
-            temporal_dims=("r", "n", "n"),
+            temporal_dims=("n", "n", "n"),
             spatial_dims=("n",),
+            opts=(StreamerOpts.HasAddressRemap,),
         ),
         Streamer(  # C
             StreamerType.Reader,
-            temporal_dims=("r", "n", "n"),
-            spatial_dims=("n",),
-            opts=(StreamerOpts.HasChannelMask,),
+            temporal_dims=("n", "n", "n", "n", "n", "n", "n"),
+            spatial_dims=("n", "n"),
+            opts=(
+                StreamerOpts.HasChannelMask,
+                StreamerOpts.HasAddressRemap,
+                StreamerOpts.HasBroadcast,
+            ),
         ),
         Streamer(  # D32
             StreamerType.Writer,
-            temporal_dims=("r", "n", "n"),
-            spatial_dims=("n",),
+            temporal_dims=("n", "n", "n", "n", "n", "n", "n"),
+            spatial_dims=("n", "n"),
+            opts=(StreamerOpts.HasAddressRemap,),
         ),
     ],
 )
@@ -67,6 +69,7 @@ class SNAXGEMMXAccelerator(
     supported_kernels = (
         SupportedKernel(kernel.QMacOp, (i8, i8, i32, i32, i32)),
+        SupportedKernel(kernel.MacOp, (i8, i8, i32)),
         SupportedKernel(kernel.AddOp, (i32, i32, i32)),
         SupportedKernel(kernel.RescaleOp, (i32, i8)),
     )
@@ -162,10 +165,6 @@ def _generate_setup_vals(
         c0 = arith.Constant.from_int_and_width(0, 32)
        c1 = arith.Constant.from_int_and_width(1, 32)
-        knm: list = [
-            (((cst := arith.Constant.from_int_and_width(val.data, 32)),), cst.result)
-            for val in op.stride_patterns.data[0].upper_bounds
-        ]

         streamer_setup_vals = list(self._generate_streamer_setup_vals(op))
@@ -174,6 +173,16 @@ def _generate_setup_vals(
         assert isinstance(generic_op := op.body.block.first_op, stream.GenericOp)

         if isinstance(qmac := generic_op.body.block.first_op, kernel.QMacOp):
+            # compute knm: fix n = 1
+            n = 1
+            m = prod(x.data for x in op.stride_patterns.data[-1].upper_bounds) // n
+            k = prod(x.data for x in op.stride_patterns.data[0].upper_bounds) // m
+
+            knm: list = [
+                (((cst := arith.Constant.from_int_and_width(val, 32)),), cst.result)
+                for val in (k, n, m)
+            ]
+
             # gemm
             # bypass simd and set all related values to 0
             bypassSIMD = c1.result  # bypass simd
@@ -201,10 +210,16 @@ def _generate_setup_vals(
         elif isinstance(rescale := generic_op.body.block.first_op, kernel.RescaleOp):
             # extract and compute correct value for csr's based on kernel rescale op
-            # set k to 1
-            knm.insert(
-                0, ((cst := arith.Constant.from_int_and_width(1, 32),), cst.result)
-            )
+            # set k and n to 1
+            k = 1
+            n = 1
+            m = prod(x.data for x in op.stride_patterns.data[0].upper_bounds)
+
+            knm: list = [
+                (((cst := arith.Constant.from_int_and_width(val, 32)),), cst.result)
+                for val in (k, n, m)
+            ]
+
             # simd
             bypassSIMD = c0.result
             subtractions = c0.result
@@ -247,6 +262,38 @@ def _generate_setup_vals(
             loop_bound = arith.Constant.from_int_and_width(loop_bound, i32)
             ops_to_add.append(loop_bound)

+        elif isinstance(generic_op.body.block.first_op, kernel.MacOp):
+            # compute knm: fix n = 1
+            n = 1
+            m = prod(x.data for x in op.stride_patterns.data[-1].upper_bounds) // n
+            k = prod(x.data for x in op.stride_patterns.data[0].upper_bounds) // m
+
+            knm: list = [
+                (((cst := arith.Constant.from_int_and_width(val, 32)),), cst.result)
+                for val in (k, n, m)
+            ]
+
+            # gemm
+            # bypass simd and set all related values to 0
+            bypassSIMD = c1.result  # bypass simd
+            loop_bound = c0
+            csr0 = c0.result
+            csr1 = c0.result
+            shift_vals = (c0.result for _ in range(2))
+            mult_vals = (c0.result for _ in range(8))
+
+            # zero points for the gemm: a plain mac has none, so pack two
+            # constant zeros (no 8-bit sign masking is needed for zeros)
+            zp_a = c0
+            zp_b = c0
+
+            bitlist = list(pack_bitlist((zp_a, zp_b), [0, 8]))
+            ops_to_add.extend(bitlist)
+            subtractions = bitlist[-1].results[0]
+
         else:
             raise NotImplementedError()
diff --git a/compiler/accelerators/streamers/streamers.py b/compiler/accelerators/streamers/streamers.py
index 4fd08562..36b3033c 100644
--- a/compiler/accelerators/streamers/streamers.py
+++ b/compiler/accelerators/streamers/streamers.py
@@ -17,6 +17,11 @@ class StreamerOpts(StrEnum):
     HasTranspose = "t"
     # Streamer with channel mask capabilities
     HasChannelMask = "c"
+    # Streamer with address remap capabilities
+    HasAddressRemap = "r"
+    # Streamer with broadcast capabilities
+    HasBroadcast = "b"


 class StreamerFlag(StrEnum):
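The option flags must extend the CSR value list and the CSR field list in exactly the same order, or the streamer setup would silently misprogram registers. A minimal standalone sketch of the resulting per-streamer field ordering (hedged: only the `_tstride_`, option, and `_broadcast` names appear in the code above; the `_sstride_` and `_bound_` names are illustrative assumptions):

def setup_fields(name: str, spatial_dim: int, temporal_dim: int, opts: str) -> list[str]:
    fields = [f"{name}_sstride_{i}" for i in range(spatial_dim)]  # assumed name
    fields += [f"{name}_bound_{i}" for i in range(temporal_dim)]  # assumed name
    fields += [f"{name}_tstride_{i}" for i in range(temporal_dim)]
    if "r" in opts:  # StreamerOpts.HasAddressRemap
        fields.append(f"{name}_address_remap")
    if "c" in opts:  # StreamerOpts.HasChannelMask
        fields.append(f"{name}_channel_mask")
    if "t" in opts:  # StreamerOpts.HasTranspose
        fields.append(f"{name}_transpose")
    return fields

# broadcast fields ("{name}_broadcast") are appended only after all
# per-streamer fields, mirroring the trailing loop in the diff above
print(setup_fields("A", 1, 7, "tr"))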
diff --git a/compiler/dialects/kernel.py b/compiler/dialects/kernel.py
index b9a61774..f3e443a3 100644
--- a/compiler/dialects/kernel.py
+++ b/compiler/dialects/kernel.py
@@ -101,7 +101,10 @@ def equivalent_region(self) -> Region:
             )
         )
         def equivalent_region(args: tuple[BlockArgument, ...]) -> None:
-            mul = arith.Muli(args[0], args[1])
+            assert isinstance(output_type := args[2].type, IntegerType)
+            # sign-extend the narrow inputs to the accumulator type first
+            inp1 = arith.ExtSIOp(args[0], output_type)
+            inp2 = arith.ExtSIOp(args[1], output_type)
+            mul = arith.Muli(inp1, inp2)
             mac = arith.Addi(args[2], mul)
             linalg.YieldOp(mac)
diff --git a/compiler/dialects/snax_stream.py b/compiler/dialects/snax_stream.py
index 26eec2f6..74f22950 100644
--- a/compiler/dialects/snax_stream.py
+++ b/compiler/dialects/snax_stream.py
@@ -1,5 +1,6 @@
 from collections.abc import Sequence

+from typing_extensions import Self
 from xdsl.dialects.builtin import ArrayAttr, IndexType, IntAttr, StringAttr
 from xdsl.ir import (
     Attribute,
@@ -111,6 +112,32 @@ def parse_parameters(cls, parser: AttrParser) -> Sequence[Attribute]:
         )
         return (ub, ts, ss)

+    def collapse_dimensions(self) -> Self:
+        """
+        Collapses perfectly nested, compatible for loops into a single loop:
+
+            for i = 0 .. 4:
+                for j = 0 .. 4:
+                    a = 4 * i + j
+
+        will turn into:
+
+            for i = 0 .. 16:
+                a = i
+        """
+        new_temporal_strides: list[int] = []
+        new_upper_bounds: list[int] = []
+
+        for stride, bound in zip(self.temporal_strides.data, self.upper_bounds.data):
+            if bound.data == 1:
+                # unused dim
+                continue
+            if len(new_temporal_strides) > 0:
+                if new_temporal_strides[-1] * new_upper_bounds[-1] == stride.data:
+                    # contiguous continuation: merge into the previous dim
+                    new_upper_bounds[-1] *= bound.data
+                    continue
+            new_upper_bounds.append(bound.data)
+            new_temporal_strides.append(stride.data)
+        return type(self)(new_upper_bounds, new_temporal_strides, self.spatial_strides)
+

 @irdl_op_definition
 class StreamingRegionOp(IRDLOperation):
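As a concrete illustration of the collapsing rule (a sketch; `StridePattern` takes plain integer sequences, innermost dimension first, exactly as it is constructed elsewhere in this diff):

from compiler.dialects.snax_stream import StridePattern

p = StridePattern(upper_bounds=[4, 4], temporal_strides=[1, 4], spatial_strides=[8])
# 1 * 4 == 4: the second dimension contiguously continues the first,
# so the two loops collapse into one with bound 16 and stride 1
print(p.collapse_dimensions())  # expected: upper_bounds [16], temporal_strides [1]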
""" - name = "stream.streaming_region" - inputs = var_operand_def(AnyShapedType()) outputs = var_operand_def(AnyShapedType()) result_tensors = var_result_def() @@ -147,6 +147,55 @@ def get_static_pattern_bounds(self) -> Iterable[int]: tuple(self.get_static_shapes()), [] ) +@irdl_op_definition +class StreamingRegionOp(StreamingRegionOpBase): + + name = "stream.streaming_region" + + + +@irdl_op_definition +class ScheduleOp(IRDLOperation): + + name = "stream.schedule" + + inputs = var_operand_def(AnyShapedType()) + outputs = var_operand_def(AnyShapedType()) + result_tensors = var_result_def() + patterns = prop_def(ArrayAttr[AffineMapAttr]) + bounds = prop_def(ParameterDef[ArrayAttr[ArrayAttr[IntAttr]]]) + + body = region_def("single_block") + + accelerator = opt_prop_def(StringAttr) + + irdl_options = [AttrSizedOperandSegments(as_property=True)] + + def __init__( + self, + inputs: Sequence[SSAValue | Operation], + outputs: Sequence[SSAValue | Operation], + patterns: ArrayAttr[AffineMapAttr], + bounds: Sequence[Sequence[int]], + body: Region, + accelerator: str | StringAttr | None = None, + result_types: Sequence[Attribute] = (), + ) -> None: + if isinstance(accelerator, str): + accelerator = StringAttr(accelerator) + + bounds_attr = ArrayAttr(ArrayAttr(IntAttr(x) for x in y) for y in bounds) + super().__init__( + operands=[inputs, outputs], + regions=[body], + properties={ + "patterns": patterns, + "accelerator": accelerator, + "bounds": bounds_attr, + }, + result_types=[result_types], + ) + @irdl_op_definition class YieldOp(AbstractYieldOperation[Attribute]): @@ -197,6 +246,7 @@ def __init__( "stream", [ StreamingRegionOp, + ScheduleOp, GenericOp, YieldOp, ], diff --git a/compiler/ir/autoflow/__init__.py b/compiler/ir/autoflow/__init__.py new file mode 100644 index 00000000..c4e01d13 --- /dev/null +++ b/compiler/ir/autoflow/__init__.py @@ -0,0 +1 @@ +from .scheduler import * diff --git a/compiler/ir/autoflow/scheduler.py b/compiler/ir/autoflow/scheduler.py new file mode 100644 index 00000000..d73455e8 --- /dev/null +++ b/compiler/ir/autoflow/scheduler.py @@ -0,0 +1,130 @@ +from collections.abc import Iterable, Iterator + +from compiler.ir.stream import Schedule, Template +from compiler.ir.stream.access_pattern import SchedulePattern +from compiler.util.multiset import Multiset + + +def scheduler_backtrack(template: Template, schedule: Schedule, pure_output_stationary: bool, dim=1) -> Iterator[Schedule]: + # print(f'Running Scheduler Backtracking for dim = {dim}') + # print('Schedule:') + # print(schedule) + # print('Template:') + # print(template) + + if dim - 1 >= schedule.num_dims: + yield schedule + + N = schedule.num_dims - dim + K = template.num_dims - dim + + for n in range(N + 1): + # print('Checking the following schedule:') + # print(schedule) + + if K > 0: + schedule_check = schedule.disable_dims(N) + template_check = template.disable_dims(K) + else: + schedule_check = schedule.disable_dims(schedule.num_dims - template.num_dims) + template_check = template + + # print('Template check:') + # print(template_check) + # print('Schedule check:') + # print(schedule_check) + + if template_check.matches(schedule_check): + if dim > template.num_dims: + # programmatic acces, should be fine from now on + # extra check: constrain to output-stationary + i = schedule.num_dims - dim + ok = True + + # make sure to be at least one output stationary + if dim == template.num_dims + 1: + if schedule[-1].depends_on(i): + ok = False + + # extra check (1): constrain to pure output stationary + if 
diff --git a/compiler/ir/autoflow/__init__.py b/compiler/ir/autoflow/__init__.py
new file mode 100644
index 00000000..c4e01d13
--- /dev/null
+++ b/compiler/ir/autoflow/__init__.py
@@ -0,0 +1 @@
+from .scheduler import *
diff --git a/compiler/ir/autoflow/scheduler.py b/compiler/ir/autoflow/scheduler.py
new file mode 100644
index 00000000..d73455e8
--- /dev/null
+++ b/compiler/ir/autoflow/scheduler.py
@@ -0,0 +1,130 @@
+from collections.abc import Iterator
+
+from compiler.ir.stream import Schedule, Template
+
+
+def scheduler_backtrack(
+    template: Template,
+    schedule: Schedule,
+    pure_output_stationary: bool,
+    dim: int = 1,
+) -> Iterator[Schedule]:
+    if dim - 1 >= schedule.num_dims:
+        # all dims handled: this schedule is a valid candidate
+        yield schedule
+        return
+
+    N = schedule.num_dims - dim
+    K = template.num_dims - dim
+
+    # maximum number of rotations
+    for _ in range(N + 1):
+        if K > 0:
+            schedule_check = schedule.disable_dims(N)
+            template_check = template.disable_dims(K)
+        else:
+            schedule_check = schedule.disable_dims(schedule.num_dims - template.num_dims)
+            template_check = template
+
+        if template_check.matches(schedule_check):
+            if dim > template.num_dims:
+                # programmatic access from here on; only the stationarity
+                # and memory-flexibility checks remain
+                i = schedule.num_dims - dim
+                ok = True
+
+                # the innermost extra dim must keep the output stationary
+                if dim == template.num_dims + 1:
+                    if schedule[-1].depends_on(i):
+                        ok = False
+
+                # extra check (1): constrain to pure output stationarity
+                if pure_output_stationary:
+                    if schedule[-1].depends_on(i):
+                        # no further reductions can be allowed above this dim
+                        while i >= 0:
+                            if not schedule[-1].depends_on(i):
+                                # not output stationary
+                                ok = False
+                            i -= 1
+
+                # extra check (2): make sure there is enough memory flexibility
+                def generate_one_list(n: int, i: int):
+                    return [1 if j == i else 0 for j in range(n)]
+
+                # only keep spatial dims:
+                for sp in schedule:
+                    res = sp.disable_dims(
+                        schedule.num_dims - template.num_dims
+                    ).pattern.eval([1] * template.num_dims, ())
+                    # for these dimensions, only one of the upper loops may
+                    # have a stride that is not divisible by 8
+                    nbs_left = 1
+                    for idx in [index for index, value in enumerate(res) if value > 0]:
+                        for i in range(schedule.num_dims - template.num_dims):
+                            result = sp.pattern.eval(
+                                generate_one_list(schedule.num_dims, i), ()
+                            )[idx]
+                            if result % 8 != 0:
+                                nbs_left -= 1
+                                if nbs_left < 0:
+                                    # not legal for memory
+                                    ok = False
+
+                if ok:
+                    yield from scheduler_backtrack(
+                        template, schedule, pure_output_stationary, dim + 1
+                    )
+
+            else:
+                # check bounds
+                template_bound = template[0].bounds[-dim]
+                assert template_bound  # must have a bound from now on
+                schedule_bound = schedule[0].bounds[-dim]
+
+                if schedule_bound == template_bound:
+                    # perfect match, no transformation needed
+                    yield from scheduler_backtrack(
+                        template, schedule, pure_output_stationary, dim + 1
+                    )
+                elif schedule_bound < template_bound:
+                    # pad the schedule bound up to the template bound
+                    padded_schedule = schedule.pad_dim(N, template_bound)
+                    yield from scheduler_backtrack(
+                        template, padded_schedule, pure_output_stationary, dim + 1
+                    )
+                else:
+                    if schedule_bound % template_bound != 0:
+                        # imperfect factorization: pad up to the next multiple
+                        # of the template bound before tiling
+                        padded_schedule = schedule.pad_dim(
+                            N,
+                            schedule_bound
+                            + template_bound
+                            - (schedule_bound % template_bound),
+                        )
+                        tiled_schedule = padded_schedule.tile_dim(N, template_bound)
+                    else:
+                        # bounds divide evenly: apply tiling directly
+                        tiled_schedule = schedule.tile_dim(N, template_bound)
+                    yield from scheduler_backtrack(
+                        template, tiled_schedule, pure_output_stationary, dim + 1
+                    )
+
+        # no match, or all candidates for this rotation exhausted: rotate
+        schedule = schedule.rotate(N + 1)
+
+
+def scheduler(
+    template: Template,
+    schedule: Schedule,
+    schedule_idx: int = 0,
+    pure_output_stationary: bool = True,
+) -> Schedule:
+    # prune away the 1-bounded dimensions:
+    schedule = schedule.clear_unused_dims()
+
+    schedules = list(scheduler_backtrack(template, schedule, pure_output_stationary))
+
+    # select the candidate schedule at schedule_idx
+    return schedules[schedule_idx]
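To see the scheduler end to end, here is a hedged sketch that schedules a 16x16x16 matmul onto the hardcoded 8x8x8 gemmx template used by the transforms below; how many candidates the backtracking yields depends on the matching and stationarity rules above:

from xdsl.ir.affine import AffineDimExpr, AffineMap

from compiler.ir.autoflow import scheduler
from compiler.ir.stream import Schedule, SchedulePattern, Template
from compiler.ir.stream.access_pattern import TemplatePattern

m, n, k = (AffineDimExpr(i) for i in range(3))
maps = [AffineMap(3, 0, (m, k)), AffineMap(3, 0, (k, n)), AffineMap(3, 0, (m, n))]

# the gemmx template: an 8x8x8 tile with fixed access maps
template = Template(TemplatePattern((8, 8, 8), tp) for tp in maps)
# a 16x16x16 matmul workload with the same access maps
schedule = Schedule(SchedulePattern((16, 16, 16), tp) for tp in maps)

tiled = scheduler(template, schedule)
print(tiled)  # each pattern now carries the extra tiled loop dimensions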
diff --git a/compiler/ir/stream/__init__.py b/compiler/ir/stream/__init__.py
index a7086edf..65459c57 100644
--- a/compiler/ir/stream/__init__.py
+++ b/compiler/ir/stream/__init__.py
@@ -1,2 +1,3 @@
 from .access_pattern import *
+from .optimizer import *
 from .scheduler import *
diff --git a/compiler/ir/stream/access_pattern.py b/compiler/ir/stream/access_pattern.py
index 03a3a1f2..a1bf15f5 100644
--- a/compiler/ir/stream/access_pattern.py
+++ b/compiler/ir/stream/access_pattern.py
@@ -3,10 +3,17 @@
 from dataclasses import dataclass
 from typing import Generic

-from typing_extensions import Self, TypeVar, deprecated, overload
-from xdsl.ir.affine import AffineConstantExpr, AffineDimExpr, AffineMap
+from typing_extensions import Self, TypeVar, cast, deprecated, overload
+from xdsl.ir.affine import (
+    AffineBinaryOpExpr,
+    AffineConstantExpr,
+    AffineDimExpr,
+    AffineExpr,
+    AffineMap,
+)

 from compiler.util.canonicalize_affine import canonicalize_map
+from compiler.util.multiset import Multiset


 @dataclass(frozen=True)
@@ -63,6 +70,28 @@ def disable_dims(self, dim: int) -> Self:
         )
         return type(self)(self.bounds[dim:], new_pattern)

+    def depends_on(self, dim: int) -> bool:
+        """
+        Returns whether this access pattern depends on a given dimension.
+
+        For example:
+            (d0, d1, d2) -> d0 + d1
+        will return True for `dim` = 0 and 1, False for `dim` = 2
+        """
+
+        # helper function to determine if an affine expression uses the dim
+        def expr_depends_on(expr: AffineExpr, dim: int) -> bool:
+            if isinstance(expr, AffineBinaryOpExpr):
+                return expr_depends_on(expr.lhs, dim) or expr_depends_on(expr.rhs, dim)
+            elif isinstance(expr, AffineDimExpr):
+                return expr.position == dim
+            return False
+
+        return any(expr_depends_on(result, dim) for result in self.pattern.results)
+
+    def __str__(self) -> str:
+        return str(self.bounds) + str(self.pattern)
+

 @dataclass(frozen=True)
 class SchedulePattern(AccessPattern):
@@ -93,7 +122,6 @@ def rotate(self, dim: int) -> Self:
             (d0, d1, d2) -> 3 * d0 + 1 * d1 + 2 * d2
         For `dim` = 2, will return:
             (d0, d1, d2) -> 2 * d0 + 1 * d1 + 3 * d2
-        return AccessPattern()
         """

         # Rotate in the following manner:
@@ -109,6 +137,24 @@ def rotate(self, dim: int) -> Self:
         )
         return type(self)(new_bounds, new_pattern)

+    def reorder(self, new_order: Sequence[int]) -> Self:
+        """
+        Returns a new SchedulePattern with dimensions reordered, such that
+        new dimension `j` takes the role of old dimension `new_order[j]`.
+
+        For example:
+            (d0, d1, d2) -> 1 * d0 + 2 * d1 + 3 * d2
+        For `new_order` = (0, 2, 1), will return:
+            (d0, d1, d2) -> 1 * d0 + 3 * d1 + 2 * d2
+        For `new_order` = (1, 0, 2), will return:
+            (d0, d1, d2) -> 2 * d0 + 1 * d1 + 3 * d2
+        """
+        # old dimension i ends up at the position where i appears in new_order,
+        # so substitute d_i with d_{new_order.index(i)} (the inverse permutation)
+        new_dims = [AffineDimExpr(new_order.index(i)) for i in range(self.num_dims)]
+        new_pattern = self.pattern.replace_dims_and_symbols(
+            new_dims, [], self.num_dims, 0
+        )
+        new_bounds = [self.bounds[i] for i in new_order]
+        return type(self)(new_bounds, new_pattern)
+
     def tile_dim(self, dim: int, template_bound: int) -> Self:
         """
         Returns a new access pattern with the `dim` dimension split up into two
@@ -146,6 +192,15 @@ def tile_dim(self, dim: int, template_bound: int) -> Self:

         return type(self)(new_bounds, new_pattern)

+    def pad_dim(self, dim: int, template_bound: int) -> Self:
+        """
+        Returns a new schedule pattern with the `dim` dimension padded up
+        to the template bound
+        """
+        new_bounds = list(self.bounds)
+        new_bounds[dim] = template_bound
+        return type(self)(new_bounds, self.pattern)
+
     def add_dim(self) -> Self:
         """
         Returns a new schedule pattern with an extra empty dimension inserted.
@@ -171,6 +226,12 @@ class TemplatePattern(AccessPattern):

     Template pattern is a pattern for an accelerator template.
     Templates should not be transformed through either tiling/rotating/others.
+
+    Template bounds have some special properties:
+      - bound is None: the bound can be programmed to anything from 0 to inf
+      - bound <= 0: undefined behaviour
+      - bound == 1: the dimension is kept, but must be scheduled with bound 1
+        (used for stationarity)
+      - bound > 1: the bound is fixed to exactly this value
     """

     def __init__(self, bounds: Sequence[int | None], pattern: AffineMap):
@@ -181,11 +242,9 @@ def matches(self, sp: SchedulePattern):
         """
         Check if a given schedule pattern matches this template pattern.
""" - if sp.num_dims != self.num_dims: - return False - if sp.pattern != self.pattern: - return False - return True + self_pattern = canonicalize_map(self.pattern, extreme=True) + schedule_pattern = canonicalize_map(sp.pattern, extreme=True) + return Multiset(self_pattern.results).is_subset(Multiset(schedule_pattern.results)) P = TypeVar("P", bound=AccessPattern) @@ -223,7 +282,9 @@ def __eq__(self, other: object) -> bool: return self._patterns == other._patterns @property - @deprecated("only valid in trivial cases") + @deprecated( + "only valid in trivial cases (hindsight: no, this should always be valid)" + ) def num_dims(self) -> int: return self[0].num_dims @@ -231,6 +292,21 @@ def num_dims(self) -> int: def max_dim(self) -> int: return max(pattern.num_dims for pattern in self._patterns) + @property + def bounds(self) -> list[int | None]: + """ + Returns the combined bounds of all patterns in this collection + """ + result: list[int | None] = [] + for i in range(self.num_dims): + combined_bounds = [pattern.bounds[i] for pattern in self._patterns] + if any([x is None for x in combined_bounds]): + result.append(None) + else: + combined_bounds = cast(list[int], combined_bounds) + result.append(max(combined_bounds)) + return result + def disable_dims(self, dim: int) -> Self: return type(self)(sp.disable_dims(dim) for sp in self) @@ -240,7 +316,7 @@ def clear_unused_dims(self, bounds: tuple[int] | None = None) -> Self: Optionally, specify custom bounds. """ if bounds is None: - pattern_bounds = self._patterns[0].bounds + pattern_bounds = tuple(self.bounds) else: pattern_bounds = bounds unused_dims = tuple(i for i, bound in enumerate(pattern_bounds) if bound == 1) @@ -254,7 +330,11 @@ def clear_unused_dims(self, bounds: tuple[int] | None = None) -> Self: unused_counter += 1 return type(self)( type(self._patterns[0])( - tuple(bound for bound in pattern_bounds if bound != 1), + tuple( + bound + for bound, pattern_bound in zip(sp.bounds, pattern_bounds) + if pattern_bound != 1 + ), sp.pattern.replace_dims_and_symbols( dim_substitutions, [], self.num_dims - unused_counter, 0 ), @@ -262,6 +342,9 @@ def clear_unused_dims(self, bounds: tuple[int] | None = None) -> Self: for sp in self ) + def __str__(self) -> str: + return "\n".join(str(pattern) for pattern in self._patterns) + class Schedule(PatternCollection[SchedulePattern]): """ @@ -271,9 +354,15 @@ class Schedule(PatternCollection[SchedulePattern]): def rotate(self, dim: int) -> Self: return type(self)(sp.rotate(dim) for sp in self) + def reorder(self, new_order: Sequence[int]) -> Self: + return type(self)(sp.reorder(new_order) for sp in self) + def tile_dim(self, dim: int, template_bound: int) -> Self: return type(self)(sp.tile_dim(dim, template_bound) for sp in self) + def pad_dim(self, dim: int, template_bound: int) -> Self: + return type(self)(sp.pad_dim(dim, template_bound) for sp in self) + def add_dim(self) -> Self: return type(self)(sp.add_dim() for sp in self) diff --git a/compiler/ir/stream/optimizer.py b/compiler/ir/stream/optimizer.py new file mode 100644 index 00000000..5f01e1b9 --- /dev/null +++ b/compiler/ir/stream/optimizer.py @@ -0,0 +1,23 @@ +from compiler.ir.stream import Schedule, Template + + +def optimizer(template: Template, schedule: Schedule) -> Schedule: + # Optimization 1: put all free-moving reduction dims as the most inner loops + + free_reduction_dims = [] + + # free-moving dims are dims that appear in schedule but not template + for i in range(schedule.num_dims - template.num_dims): + if not 
diff --git a/compiler/ir/stream/optimizer.py b/compiler/ir/stream/optimizer.py
new file mode 100644
index 00000000..5f01e1b9
--- /dev/null
+++ b/compiler/ir/stream/optimizer.py
@@ -0,0 +1,23 @@
+from compiler.ir.stream import Schedule, Template
+
+
+def optimizer(template: Template, schedule: Schedule) -> Schedule:
+    # Optimization 1: move all free-moving reduction dims to the innermost loops
+
+    free_reduction_dims: list[int] = []
+
+    # free-moving dims are the extra schedule dims (those not fixed by the
+    # template) on which the output pattern does not depend
+    for i in range(schedule.num_dims - template.num_dims):
+        if not schedule[-1].depends_on(i):
+            free_reduction_dims.append(i)
+
+    # apply reordering: insert the free dims right before the template-fixed dims
+    fixed_dims = len([x for x in template[-1].bounds if x is not None])
+    new_order = [x for x in range(schedule.num_dims) if x not in free_reduction_dims]
+    new_order[-fixed_dims:-fixed_dims] = free_reduction_dims
+    schedule = schedule.reorder(new_order)
+
+    # Optimization 2: clear all unused dims of the pattern (= with bound 1)
+    schedule = schedule.clear_unused_dims()
+
+    return schedule
diff --git a/compiler/ir/stream/scheduler.py b/compiler/ir/stream/scheduler.py
index 45303fed..062b4aa8 100644
--- a/compiler/ir/stream/scheduler.py
+++ b/compiler/ir/stream/scheduler.py
@@ -9,35 +9,75 @@ def scheduler(template: Template, schedule: Schedule) -> Schedule:
         schedule_dim = schedule.num_dims - i - 1
         match = False

         # maximum number of rotations
         for _ in range(schedule_dim + 1):
             # check if there is a match
             template_check = template.disable_dims(template_dim)
             schedule_check = schedule.disable_dims(schedule_dim)

-            if template_check.matches(schedule_check):
+            if not template_check.matches(schedule_check):
+                # no match: rotate the for loops and try again
+                schedule = schedule.rotate(schedule_dim + 1)
+                continue
+
+            # now, check bounds and design potential transformation map
+            if not (template_bound := template[0].bounds[template_dim]):
+                # nothing to worry about, continue to next dim
                 match = True
                 break

-            # else rotate the for loops
-            schedule = schedule.rotate(schedule_dim + 1)
+            schedule_bound = schedule[0].bounds[schedule_dim]
+
+            if schedule_bound < template_bound:
+                # padding would be needed here, which is not supported yet;
+                # rotate and try to find another option instead
+                schedule = schedule.rotate(schedule_dim + 1)
+                continue
+            else:
+                # need to split up the schedule
+                assert schedule_bound % template_bound == 0
+                schedule = schedule.tile_dim(schedule_dim, template_bound)
+                match = True
+                break

         if not match:
             raise RuntimeError("failed to match template and schedule")

-        # now, check bounds and design potential transformation map
-        if not (template_bound := template[0].bounds[template_dim]):
-            # nothing to worry about, continue to next dim
-            continue
-
-        schedule_bound = schedule[0].bounds[schedule_dim]
-
-        if schedule_bound < template_bound:
-            # need to apply padding
-            raise NotImplementedError("padding not supported")
-        elif schedule_bound >= template_bound:
-            # need to split up the schedule
-            assert schedule_bound % template_bound == 0
-            schedule = schedule.tile_dim(schedule_dim, template_bound)

     return schedule
diff --git a/compiler/ir/tsl/tiled_stride.py b/compiler/ir/tsl/tiled_stride.py
index fbf572cf..ddc19ba4 100644
--- a/compiler/ir/tsl/tiled_stride.py
+++ b/compiler/ir/tsl/tiled_stride.py
@@ -63,6 +63,20 @@ def __str__(self) -> str:
         )
         return f"[{bounds}] -> ({strides})"

+    def simplify(self) -> "TiledStride":
+        strides: list[Stride] = []
+        # walk from the innermost (last) stride level to the outermost
+        for stride in reversed(self.strides):
+            if not strides:
+                strides.append(stride)
+            elif stride.bound == 1:
+                # unused level
+                continue
+            elif (
+                strides[0].step is not None
+                and strides[0].bound is not None
+                and stride.step == strides[0].bound * strides[0].step
+            ):
+                # this level contiguously continues the current outermost level,
+                # so merge the two (note: compare against the outermost level,
+                # not the innermost one)
+                strides[0] = Stride(strides[0].step, strides[0].bound * stride.bound)
+            else:
+                strides.insert(0, stride)
+
+        return TiledStride(strides)
+
     def __iter__(self) -> Iterator[tuple[int, Stride]]:
         """Returns an iterator of (depth, stride) over all the strides of the
         Tiled Stride
diff --git a/compiler/ir/tsl/tiled_strided_layout.py b/compiler/ir/tsl/tiled_strided_layout.py
index c14d4764..b129fe5b 100644
--- a/compiler/ir/tsl/tiled_strided_layout.py
+++ b/compiler/ir/tsl/tiled_strided_layout.py
@@ -19,7 +19,7 @@ class TiledStridedLayout:
     """

     tstrides: list[TiledStride]
-    offset: int | None = 0
+    offset: int = 0

     @staticmethod
     def from_strides(
@@ -76,6 +76,9 @@ def get_stride(self, dim: int, depth: int) -> Stride:
         the Tiled Strided Layout"""
         return self.tstrides[dim].strides[depth]

+    def simplify(self) -> "TiledStridedLayout":
+        return TiledStridedLayout(
+            [tstride.simplify() for tstride in self.tstrides], self.offset
+        )
+
     def all_values(self) -> np.ndarray:
         """
         Returns a numpy array containing all the elements in the iteration space.
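A quick sanity check of the merge rule in `TiledStride.simplify` (a sketch; the expected result follows from the contiguity condition above):

from compiler.ir.tsl import Stride, TiledStride

# outer level: step 32, bound 4; inner level: step 8, bound 4.
# 8 * 4 == 32, so the outer level contiguously continues the inner one:
ts = TiledStride([Stride(32, 4), Stride(8, 4)])
print(ts.simplify())  # expected: a single level with step 8 and bound 16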
diff --git a/compiler/tools/snax_opt_main.py b/compiler/tools/snax_opt_main.py
index 1e50f0f0..31d8a8b6 100644
--- a/compiler/tools/snax_opt_main.py
+++ b/compiler/tools/snax_opt_main.py
@@ -15,6 +15,8 @@
 from compiler.transforms.accfg_dedup import AccfgDeduplicate
 from compiler.transforms.accfg_insert_resets import InsertResetsPass
 from compiler.transforms.alloc_to_global import AllocToGlobalPass
+from compiler.transforms.autoflow_layout_resolution import AutoflowLayoutResolutionPass
+from compiler.transforms.autoflow_scheduler import AutoflowSchedulerPass
 from compiler.transforms.clear_memory_space import ClearMemorySpace
 from compiler.transforms.convert_accfg_to_csr import ConvertAccfgToCsrPass
 from compiler.transforms.convert_kernel_to_linalg import ConvertKernelToLinalg
@@ -46,6 +48,7 @@
 from compiler.transforms.test.debug_to_func import DebugToFuncPass
 from compiler.transforms.test.insert_debugs import InsertDebugPass
 from compiler.transforms.test.test_add_mcycle_around_launch import AddMcycleAroundLaunch
+from compiler.transforms.test.test_remove_copies import RemoveCopiesPass
 from compiler.transforms.test_add_mcycle_around_loop import AddMcycleAroundLoopPass
 from compiler.transforms.test_remove_memref_copy import RemoveMemrefCopyPass

@@ -126,6 +129,9 @@ def __init__(
         super().register_pass(SnaxBufferize.name, lambda: SnaxBufferize)
         super().register_pass(FuseStreamingRegions.name, lambda: FuseStreamingRegions)
         super().register_pass(AllocToGlobalPass.name, lambda: AllocToGlobalPass)
+        super().register_pass(AutoflowSchedulerPass.name, lambda: AutoflowSchedulerPass)
+        super().register_pass(
+            AutoflowLayoutResolutionPass.name, lambda: AutoflowLayoutResolutionPass
+        )
+        super().register_pass(RemoveCopiesPass.name, lambda: RemoveCopiesPass)

         # arg handling
         arg_parser = argparse.ArgumentParser(description=description)
diff --git a/compiler/transforms/autoflow_layout_resolution.py b/compiler/transforms/autoflow_layout_resolution.py
new file mode 100644
index 00000000..6dd2da82
--- /dev/null
+++ b/compiler/transforms/autoflow_layout_resolution.py
@@ -0,0 +1,313 @@
+from dataclasses import dataclass
+
+from xdsl.context import MLContext
+from xdsl.dialects import arith, builtin, memref
+from xdsl.dialects.builtin import MemRefType
+from xdsl.ir import Operation
+from xdsl.ir.affine import AffineDimExpr, AffineMap
+from xdsl.passes import ModulePass
+from xdsl.pattern_rewriter import (
+    PatternRewriter,
+    PatternRewriteWalker,
+    RewritePattern,
+    op_type_rewrite_pattern,
+)
+
+from compiler.accelerators import find_accelerator_op
+from compiler.accelerators.registry import AcceleratorRegistry
+from compiler.accelerators.snax import SNAXStreamer
+from compiler.dialects import snax_stream, stream
+from compiler.dialects.snax import StreamerConfigurationAttr
+from compiler.ir.stream import Schedule, SchedulePattern
+from compiler.ir.stream.access_pattern import Template, TemplatePattern
+
+
+@dataclass
+class AutoflowLayoutResolutionPattern(RewritePattern):
+    """
+    A pattern to convert scheduled stream operations to snax_stream.
+
+    The scheduling itself has already been performed by the autoflow
+    scheduler; what remains boils down to combining the data access
+    patterns of a stream op (operation -> data) with a certain data
+    layout, an affine map from (data -> memory), into a mapping
+    (operation -> memory).
+
+    This takes the form of a snax_stream access pattern, mapping
+    (operation -> memory), which, in hardware, is realized by the Streamers.
+
+    Current restrictions:
+        We are only handling default memory layouts for now (NoneAttr)
+    """
+
+    @op_type_rewrite_pattern
+    def match_and_rewrite(self, op: stream.ScheduleOp, rewriter: PatternRewriter):
+        # Handle only stream ops dispatched to an accelerator:
+        if op.accelerator is None:
+            return
+
+        # Go and fetch the accelerator op
+        accelerator_str = op.accelerator.data
+        acc_op = find_accelerator_op(op, accelerator_str)
+
+        if not acc_op:
+            raise RuntimeError("AcceleratorOp not found!")
+
+        if "streamer_config" not in acc_op.attributes:
+            raise RuntimeError("Streamer interface not found for given accelerator op")
+        streamer_config = acc_op.attributes["streamer_config"]
+        assert isinstance(streamer_config, StreamerConfigurationAttr)
+
+        # get template and template_bounds
+        accelerator_type = AcceleratorRegistry().get_acc_info(acc_op)
+        assert issubclass(accelerator_type, SNAXStreamer)
+
+        # overwrite template for gemmx:
+        m, n, k = (AffineDimExpr(i) for i in range(3))
+        template = [
+            AffineMap(3, 0, (m, k)),
+            AffineMap(3, 0, (k, n)),
+            AffineMap(3, 0, (m, n)),
+        ]
+        template_bounds = (8, 8, 8)
+
+        template = Template(TemplatePattern(template_bounds, tp) for tp in template)
+
+        # Make sure the operands are memrefs
+        for memref_operand in op.operands:
+            if not isinstance(memref_operand.type, builtin.MemRefType):
+                return
+
+        # recreate schedule from op
+        schedule = Schedule(
+            SchedulePattern(
+                bounds=[x.data for x in bounds.data],
+                pattern=pattern.data,
+            )
+            for pattern, bounds in zip(op.patterns.data, op.bounds.data)
+        )
+
+        # implement stationarity: the output operand must not re-iterate over
+        # the trailing reduction dims, so fix its bounds there to 1
+        new_output_bounds = list(schedule[-1].bounds)
+
+        for dim in reversed(range(schedule.num_dims - template.num_dims)):
+            # only check output stationarity
+            if not schedule[-1].depends_on(dim):
+                new_output_bounds[dim] = 1
+            else:
+                break
+
+        schedules = [pattern for pattern in schedule]
+        schedules[-1] = SchedulePattern(new_output_bounds, schedules[-1].pattern)
+
+        # do not optimize for now
+        # schedule = optimizer(template, schedule)
+
+        # We are now ready to convert the stream access patterns into snax
+        # stride patterns: construct the strided patterns for SNAX Streamers
+
+        snax_stride_patterns: list[snax_stream.StridePattern] = []
+
+        # small function to generate a list of n zeros with the i-th element 1
+        # for example n = 4, i = 1 -> [0, 1, 0, 0]
+        def generate_one_list(n: int, i: int):
+            return [1 if j == i else 0 for j in range(n)]
+
+        # Do this for every operand:
+        for operand in range(len(op.operands)):
+            assert isinstance(memref_type := op.operands[operand].type, MemRefType)
+
+            # Mapping from data to memory:
+            data_mem_map: AffineMap = memref_type.get_affine_map_in_bytes()
+
+            # Mapping from access to data:
+            access_data_map: AffineMap = schedules[operand].pattern
+
+            # Mapping from access to memory:
+            access_mem_map: AffineMap = data_mem_map.compose(access_data_map)
+
+            # Make sure no symbols are used (not supported yet)
+            if access_mem_map.num_symbols != 0:
+                raise RuntimeError(
+                    "Access patterns with symbols are not supported yet."
+                )
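+            # As a concrete example: an access pattern (i, j) -> (i, j)
+            # composed with the row-major byte layout of a 16x16 memref of
+            # i8, (d0, d1) -> (16 * d0 + d1), yields (i, j) -> (16 * i + j);
+            # evaluating that map with one-hot inputs below recovers the
+            # temporal strides (16, 1).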
+
+            # Create an iterator over all dimensions of the access_mem_map
+            # that returns (stride, bound), innermost dimension first
+            access_iter = iter(
+                (
+                    access_mem_map.eval(
+                        generate_one_list(access_mem_map.num_dims, i), ()
+                    )[0],
+                    schedules[operand].bounds[i],
+                )
+                for i in reversed(range(access_mem_map.num_dims))
+            )
+
+            # Fetch the first stride
+            stride, bound = next(access_iter)
+
+            temporal_strides: list[int] = []
+            spatial_strides: list[int] = []
+            upper_bounds: list[int] = []
+
+            # fill up all spatial strides
+            for _ in [x for x in template[0].bounds if x is not None]:
+                # FIXME: provide a more general solution for the new spatial
+                # streamer configuration; this works in all current cases and
+                # layouts but is far from generally correct.
+                spatial_strides = [8]
+                stride, bound = next(access_iter, (None, None))
+
+            # the remaining dimensions are temporal strides
+            while stride is not None and bound is not None:
+                temporal_strides.append(stride)
+                upper_bounds.append(bound)
+                stride, bound = next(access_iter, (None, None))
+
+            # create the stride pattern for this operand
+            snax_stride_pattern = snax_stream.StridePattern(
+                upper_bounds=upper_bounds,
+                temporal_strides=temporal_strides,
+                spatial_strides=spatial_strides,
+            )
+            snax_stride_patterns.append(snax_stride_pattern)
+
+        # get base addresses of the streaming region ops
+        # TODO: generalize and fix for offsets
+
+        new_inputs: list[Operation] = [
+            memref.ExtractAlignedPointerAsIndexOp.get(input) for input in op.inputs
+        ]
+        new_outputs = [
+            memref.ExtractAlignedPointerAsIndexOp.get(output) for output in op.outputs
+        ]
+
+        # TODO: what is still required is a better system for the unused
+        # operands of snax_gemmx / other accelerators. this now fills in
+        # empty/zero patterns for the unused operands.
+
+        if acc_op.name_prop.root_reference.data == "snax_gemmx":
+            empty_pattern = snax_stream.StridePattern(
+                upper_bounds=[0] * 3, temporal_strides=[0] * 3, spatial_strides=[0]
+            )
+            if len(snax_stride_patterns) == 3:
+                # matmul
+
+                # for a matmul, the 8-bit output port D8 and the bias input C
+                # are unused, so we create extra patterns for them here.
+
+                # insert an empty pattern for D8
+                snax_stride_patterns.insert(2, empty_pattern)
+                new_inputs.append(
+                    memref.ExtractAlignedPointerAsIndexOp.get(op.inputs[-1])
+                )
+
+                # insert a pattern for C with the same bounds as D32
+                snax_stride_patterns.insert(
+                    3,
+                    snax_stream.StridePattern(
+                        upper_bounds=snax_stride_patterns[3].upper_bounds,
+                        temporal_strides=snax_stride_patterns[3].temporal_strides,
+                        spatial_strides=[8],
+                    ),
+                )
+
+                # convs: point C to D32
+                new_inputs.append(
+                    memref.ExtractAlignedPointerAsIndexOp.get(op.outputs[0])
+                )
+                # alternative: a zero pointer would generate 0 values instead:
+                # new_inputs.append(
+                #     arith.Constant.from_int_and_width(0, builtin.IndexType())
+                # )
+            elif len(snax_stride_patterns) == 4:
+                # gemm
+                #
+                # for a gemm, the 8-bit output port D8 is unused, so we
+                # create an empty pattern for it here
+                snax_stride_patterns.insert(2, empty_pattern)
+                new_inputs.insert(
+                    2, memref.ExtractAlignedPointerAsIndexOp.get(op.inputs[-1])
+                )
+
+            else:
+                # simd
+                # to calculate only simd, we calculate the result
+                # of D8 = rescale(AxB + C)
+                # create zero patterns for A and B such that D8 = rescale(C)
+                # create an empty pattern for D32
+                # do not use new outputs
+                new_inputs.append(new_outputs.pop())
+
+                zero_pattern = snax_stream.StridePattern(
+                    upper_bounds=snax_stride_patterns[0].upper_bounds,
+                    temporal_strides=[0] * len(snax_stride_patterns[0].upper_bounds),
+                    spatial_strides=[8],
+                )
+
+                # read zeros from tcdm (must make sure there are zeros at
+                # these addresses); in the new streamer this can be fixed
+                # with byte masking
+                snax_stride_patterns.insert(0, zero_pattern)
+                new_inputs.insert(
+                    0,
+                    # zero pointer will generate 0 values
+                    arith.Constant.from_int_and_width(0, builtin.IndexType()),
+                )
+                snax_stride_patterns.insert(1, zero_pattern)
+                new_inputs.insert(
+                    1,
+                    # zero pointer will generate 0 values
+                    arith.Constant.from_int_and_width(0, builtin.IndexType()),
+                )
+
+                # flip D8 and C such that they are in the right order
+                snax_stride_patterns.append(snax_stride_patterns.pop(2))
+                new_inputs.append(new_inputs.pop(2))
+
+                # empty pattern for D32
+                snax_stride_patterns.append(empty_pattern)
+                # dummy base pointer for D32
+                new_inputs.append(
+                    memref.ExtractAlignedPointerAsIndexOp.get(op.inputs[-1])
+                )
+
+            # make the last spatial stride patterns 2d
+            snax_stride_patterns[-2] = snax_stream.StridePattern(
+                upper_bounds=snax_stride_patterns[-2].upper_bounds,
+                temporal_strides=snax_stride_patterns[-2].temporal_strides,
+                spatial_strides=[8, 32],
+            )
+            snax_stride_patterns[-1] = snax_stream.StridePattern(
+                upper_bounds=snax_stride_patterns[-1].upper_bounds,
+                temporal_strides=snax_stride_patterns[-1].temporal_strides,
+                spatial_strides=[8, 32],
+            )
+
+        # now create the snax_stream streaming region op
+        new_op = snax_stream.StreamingRegionOp(
+            inputs=new_inputs,
+            outputs=new_outputs,
+            stride_patterns=snax_stride_patterns,
+            accelerator=accelerator_str,
+            body=rewriter.move_region_contents_to_new_regions(op.body),
+        )
+
+        rewriter.replace_matched_op([*new_inputs, *new_outputs, new_op], new_op.results)
+
+
+@dataclass(frozen=True)
+class AutoflowLayoutResolutionPass(ModulePass):
+    name = "autoflow-layout-resolution"
+
+    def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
+        PatternRewriteWalker(AutoflowLayoutResolutionPattern()).rewrite_module(op)
diff --git a/compiler/transforms/autoflow_scheduler.py b/compiler/transforms/autoflow_scheduler.py
new file mode 100644
index 00000000..72a7aab8
--- /dev/null
+++ b/compiler/transforms/autoflow_scheduler.py
@@ -0,0 +1,295 @@
+from dataclasses import dataclass
+
+from xdsl.context import MLContext
+from xdsl.dialects import builtin
+from xdsl.dialects.builtin import AffineMapAttr, ArrayAttr
+from xdsl.ir.affine import AffineDimExpr, AffineMap
+from xdsl.passes import ModulePass
+from xdsl.pattern_rewriter import (
+    PatternRewriter,
+    PatternRewriteWalker,
+    RewritePattern,
+    op_type_rewrite_pattern,
+)
+
+from compiler.accelerators import find_accelerator_op
+from compiler.accelerators.registry import AcceleratorRegistry
+from compiler.accelerators.snax import SNAXStreamer
+from compiler.dialects import stream
+from compiler.dialects.snax import StreamerConfigurationAttr
+from compiler.ir.autoflow import scheduler
+from compiler.ir.stream import Schedule, SchedulePattern
+from compiler.ir.stream.access_pattern import Template, TemplatePattern
+
+
+@dataclass
+class AutoflowSchedulerPattern(RewritePattern):
+    """
+    A pattern that runs the autoflow scheduling algorithm on a streaming
+    region operation dispatched to an accelerator, and materializes the
+    resulting schedule as a stream.schedule op. Converting that schedule
+    to snax_stream is left to the layout resolution pass.
+    """
+
+    schedule_idx: int = 0
+    pure_output_stationary: bool = True
+
+    @op_type_rewrite_pattern
+    def match_and_rewrite(
+        self, op: stream.StreamingRegionOp, rewriter: PatternRewriter
+    ):
+        # Handle only stream ops dispatched to an accelerator:
+        if op.accelerator is None:
+            return
+
+        # Go and fetch the accelerator op
+        accelerator_str = op.accelerator.data
+        acc_op = find_accelerator_op(op, accelerator_str)
+
+        if not acc_op:
+            raise RuntimeError("AcceleratorOp not found!")
+
+        if "streamer_config" not in acc_op.attributes:
+            raise RuntimeError("Streamer interface not found for given accelerator op")
+        streamer_config = acc_op.attributes["streamer_config"]
+        assert isinstance(streamer_config, StreamerConfigurationAttr)
+
+        # get template and template_bounds
+        accelerator_type = AcceleratorRegistry().get_acc_info(acc_op)
+        assert issubclass(accelerator_type, SNAXStreamer)
+
+        template = accelerator_type.get_template(op)
+
+        # Make sure the operands are memrefs
+        for memref_operand in op.operands:
+            if not isinstance(memref_operand.type, builtin.MemRefType):
+                return
+
+        # First, run the stream scheduling algorithm
+        schedule_bounds = tuple(op.get_static_pattern_bounds())
+        schedule = Schedule(
+            SchedulePattern(schedule_bounds, pattern.data)
+            for pattern in op.patterns.data
+        )
+
+        # overwrite template for gemmx:
+        m, n, k = (AffineDimExpr(i) for i in range(3))
+        template = [
+            AffineMap(3, 0, (m, k)),
+            AffineMap(3, 0, (k, n)),
+            AffineMap(3, 0, (m, n)),
+        ]
+        template_bounds = (8, 8, 8)
+
+        template = Template(TemplatePattern(template_bounds, tp) for tp in template)
+
+        schedule = scheduler(
+            template, schedule, self.schedule_idx, self.pure_output_stationary
+        )
+
+        # replace by a stream schedule op
+        new_op = stream.ScheduleOp(
+            op.inputs,
+            op.outputs,
+            ArrayAttr([AffineMapAttr(s.pattern) for s in schedule]),
+            [s.bounds for s in schedule],
+            rewriter.move_region_contents_to_new_regions(op.body),
+            op.accelerator,
+            op.result_types,
+        )
+
+        rewriter.replace_matched_op(new_op)
+
+
+@dataclass(frozen=True)
+class AutoflowSchedulerPass(ModulePass):
+    name = "autoflow-scheduler"
+
+    schedule_idx: int = 0
+    pure_output_stationary: bool = True
+
+    def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
+        PatternRewriteWalker(
+            AutoflowSchedulerPattern(self.schedule_idx, self.pure_output_stationary)
+        ).rewrite_module(op)
diff --git a/compiler/transforms/convert_stream_to_snax_stream.py b/compiler/transforms/convert_stream_to_snax_stream.py
index 27616ad1..6dae4e94 100644
--- a/compiler/transforms/convert_stream_to_snax_stream.py
+++ 
b/compiler/transforms/convert_stream_to_snax_stream.py @@ -18,7 +18,7 @@ from compiler.accelerators.snax import SNAXStreamer from compiler.dialects import snax_stream, stream from compiler.dialects.snax import StreamerConfigurationAttr -from compiler.ir.stream import Schedule, SchedulePattern, scheduler +from compiler.ir.stream import Schedule, SchedulePattern, optimizer, scheduler @dataclass @@ -77,6 +77,21 @@ def match_and_rewrite( ) schedule = scheduler(template, schedule) + # FIXME: for stationary bounds, this is hardcoded + schedules = [pattern for pattern in schedule] + for i in range(2, len(schedule)): + new_bounds = list(schedules[i].bounds) + if len(new_bounds) > 2: + new_bounds[-4] = 1 + if len(new_bounds) == 10: + # specific conv example + new_bounds[1] = 1 + new_bounds[2] = 1 + schedules[i] = SchedulePattern(new_bounds, schedules[i].pattern) + schedule = Schedule(schedules) + + schedule = optimizer(template, schedule) + # We are now ready to convert the stream access patterns into snax stride patterns # construct the strided patterns for SNAX Streamers @@ -143,7 +158,7 @@ def generate_one_list(n: int, i: int): upper_bounds=upper_bounds, temporal_strides=temporal_strides, spatial_strides=spatial_strides, - ) + ).collapse_dimensions() snax_stride_patterns.append(snax_stride_pattern) # get base addresses of the streaming region ops @@ -243,6 +258,18 @@ def generate_one_list(n: int, i: int): memref.ExtractAlignedPointerAsIndexOp.get(op.inputs[-1]) ) + # make last spatial stride patterns 2d + snax_stride_patterns[-2] = snax_stream.StridePattern( + upper_bounds=snax_stride_patterns[-2].upper_bounds, + temporal_strides=snax_stride_patterns[-2].temporal_strides, + spatial_strides=[8, 32], + ) + snax_stride_patterns[-1] = snax_stream.StridePattern( + upper_bounds=snax_stride_patterns[-1].upper_bounds, + temporal_strides=snax_stride_patterns[-1].temporal_strides, + spatial_strides=[8, 32], + ) + # now create snax_streaming region op new_op = snax_stream.StreamingRegionOp( inputs=new_inputs, diff --git a/compiler/transforms/realize_memref_casts.py b/compiler/transforms/realize_memref_casts.py index 9fda71bc..5d507e86 100644 --- a/compiler/transforms/realize_memref_casts.py +++ b/compiler/transforms/realize_memref_casts.py @@ -100,6 +100,8 @@ def match_and_rewrite( is_input = op.results[0] in use_op.inputs elif isinstance(use_op, stream.StreamingRegionOp): is_input = op.results[0] in use_op.inputs + elif isinstance(use_op, stream.ScheduleOp): + is_input = op.results[0] in use_op.inputs else: is_input = True if is_input: @@ -119,6 +121,8 @@ def match_and_rewrite( is_output = op.results[0] in use_op.outputs elif isinstance(use_op, stream.StreamingRegionOp): is_output = op.results[0] in use_op.outputs + elif isinstance(use_op, stream.ScheduleOp): + is_output = op.results[0] in use_op.outputs elif isinstance(use_op, func.Return): is_output = False else: diff --git a/compiler/transforms/set_memory_layout.py b/compiler/transforms/set_memory_layout.py index 051059e1..0bcc8bfb 100644 --- a/compiler/transforms/set_memory_layout.py +++ b/compiler/transforms/set_memory_layout.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from math import prod from xdsl.context import MLContext from xdsl.dialects import builtin, linalg @@ -17,6 +18,7 @@ from compiler.dialects.kernel import AddOp, QMacOp, RescaleOp from compiler.dialects.snax import LayoutCast from compiler.dialects.tsl import TiledStridedLayoutAttr +from compiler.ir.stream import Schedule, SchedulePattern from compiler.ir.tsl import Stride, 
TiledStride, TiledStridedLayout @@ -36,9 +38,7 @@ def match_and_rewrite(self, linalg_op: linalg.Generic, rewriter: PatternRewriter return shaped_operands: list[MemRefType] = [ - op.type - for op in linalg_op.operands - if isinstance(op.type, builtin.MemRefType) + op.type for op in linalg_op.operands if isinstance(op.type, builtin.MemRefType) ] m = shaped_operands[0].get_shape()[0] @@ -54,9 +54,7 @@ def match_and_rewrite(self, linalg_op: linalg.Generic, rewriter: PatternRewriter [ TiledStride( [ - Stride( - 256 * n // 8 if n else None, m // 8 if m else None - ), + Stride(256 * n // 8 if n else None, m // 8 if m else None), Stride(8, 8), ] ), @@ -70,9 +68,7 @@ def match_and_rewrite(self, linalg_op: linalg.Generic, rewriter: PatternRewriter [ TiledStride( [ - Stride( - 256 * n // 8 if n else None, m // 8 if m else None - ), + Stride(256 * n // 8 if n else None, m // 8 if m else None), Stride(8, 8), ] ), @@ -82,13 +78,9 @@ def match_and_rewrite(self, linalg_op: linalg.Generic, rewriter: PatternRewriter ) # insert layout_cast ops - new_input_a = LayoutCast.from_type_and_target_layout( - linalg_op.inputs[0], tsl_input - ) + new_input_a = LayoutCast.from_type_and_target_layout(linalg_op.inputs[0], tsl_input) - new_output = LayoutCast.from_type_and_target_layout( - linalg_op.outputs[0], tsl_output - ) + new_output = LayoutCast.from_type_and_target_layout(linalg_op.outputs[0], tsl_output) new_linalg_op = linalg.Generic( inputs=[new_input_a.dest], @@ -126,9 +118,7 @@ class AddMemoryLayout(RewritePattern): gemm_layout: GemmLayout = GemmLayout.cyclic @op_type_rewrite_pattern - def match_and_rewrite( - self, op: linalg.Generic | stream.StreamingRegionOp, rewriter: PatternRewriter - ): + def match_and_rewrite(self, op: linalg.Generic | stream.StreamingRegionOp, rewriter: PatternRewriter): # check if operation is dispatched via library call, as set by e.g. 
# the dispatch-kernels pass @@ -152,9 +142,10 @@ def match_and_rewrite( if not isinstance(op.body.block.first_op, QMacOp): return elif isinstance(op, stream.StreamingRegionOp): - assert isinstance( - generic_op := op.body.block.first_op, stream.GenericOp - ) + # TODO: this is a bit hacky, detect conv/gemm based on rank of input tensor: + if len(op.operands[0].type.get_shape()) > 2: + return + assert isinstance(generic_op := op.body.block.first_op, stream.GenericOp) if not isinstance(generic_op.body.block.first_op, QMacOp): return @@ -167,9 +158,7 @@ def match_and_rewrite( # get m, n, k shaped_operands: list[tuple[int, MemRefType]] = [ - (index, op.type) - for index, op in enumerate(op.operands) - if isinstance(op.type, builtin.MemRefType) + (index, op.type) for index, op in enumerate(op.operands) if isinstance(op.type, builtin.MemRefType) ] m = shaped_operands[0][1].get_shape()[0] @@ -202,9 +191,7 @@ def match_and_rewrite( Stride(8, 8), ] ), - TiledStride( - [Stride(tile_stride, k // 8 if k else None), Stride(1, 8)] - ), + TiledStride([Stride(tile_stride, k // 8 if k else None), Stride(1, 8)]), ] ) ) @@ -214,9 +201,7 @@ def match_and_rewrite( tsl_input_b = TiledStridedLayoutAttr( TiledStridedLayout( [ - TiledStride( - [Stride(tile_stride, k // 8 if k else None), Stride(1, 8)] - ), + TiledStride([Stride(tile_stride, k // 8 if k else None), Stride(1, 8)]), TiledStride( [ Stride( @@ -236,9 +221,7 @@ def match_and_rewrite( [ TiledStride( [ - Stride( - 64 * n // 8 if n else None, m // 8 if m else None - ), + Stride(64 * n // 8 if n else None, m // 8 if m else None), Stride(8, 8), ] ), @@ -248,27 +231,143 @@ def match_and_rewrite( ) # insert layout_cast ops - new_input_a = LayoutCast.from_type_and_target_layout( - op.inputs[0], tsl_input_a - ) + new_input_a = LayoutCast.from_type_and_target_layout(op.inputs[0], tsl_input_a) + + new_input_b = LayoutCast.from_type_and_target_layout(op.inputs[1], tsl_input_b) + + new_output = LayoutCast.from_type_and_target_layout(op.outputs[0], tsl_output) + + rewriter.insert_op((new_input_a, new_input_b, new_output), InsertPoint.before(op)) + + if has_add_c: + rewriter.insert_op( + new_input_c := LayoutCast.from_type_and_target_layout(op.inputs[2], tsl_output), + InsertPoint.before(op), + ) + op.operands[shaped_operands[0][0]] = new_input_a.dest + op.operands[shaped_operands[1][0]] = new_input_b.dest + op.operands[shaped_operands[2][0]] = new_input_c.dest + op.operands[shaped_operands[3][0]] = new_output.dest + else: + op.operands[shaped_operands[0][0]] = new_input_a.dest + op.operands[shaped_operands[1][0]] = new_input_b.dest + op.operands[shaped_operands[2][0]] = new_output.dest + + +@dataclass +class AddConvMemoryLayout(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: linalg.Generic | stream.StreamingRegionOp, rewriter: PatternRewriter): + # check if operation is dispatched via library call, as set by e.g. 
+ # the dispatch-kernels pass + + if isinstance(op, linalg.Generic): + if op.library_call is None: + return + else: + library_call = op.library_call.data + elif isinstance(op, stream.StreamingRegionOp): + if op.accelerator is None: + return + else: + library_call = op.accelerator.data + + has_add_c = False + + # check for library call + if library_call == "snax_gemmx" or library_call == "snax_gemmx_stream": + # only do so for qmac kernels + if isinstance(op, linalg.Generic): + if not isinstance(op.body.block.first_op, QMacOp): + return + elif isinstance(op, stream.StreamingRegionOp): + # TODO: this is a bit hacky, detect conv/gemm based on rank of input tensor: + if len(op.operands[0].type.get_shape()) != 4: + return + assert isinstance(generic_op := op.body.block.first_op, stream.GenericOp) + if not isinstance(generic_op.body.block.first_op, QMacOp): + return + + # the layout should be as static as the memref is. no more, no less + # get b, ox, oy, fx, fy, c, k + + shaped_operands: list[tuple[int, MemRefType]] = [ + (index, op.type) for index, op in enumerate(op.operands) if isinstance(op.type, builtin.MemRefType) + ] + + # do not alter existing set layouts + for _, memreftype in shaped_operands: + if isinstance(memreftype.layout, TiledStridedLayoutAttr): + return + + b = shaped_operands[0][1].get_shape()[0] + ix = shaped_operands[0][1].get_shape()[1] + iy = shaped_operands[0][1].get_shape()[2] + ox = shaped_operands[2][1].get_shape()[1] + oy = shaped_operands[2][1].get_shape()[2] + fx = shaped_operands[1][1].get_shape()[1] + fy = shaped_operands[1][1].get_shape()[2] + c = shaped_operands[0][1].get_shape()[3] + k = shaped_operands[2][1].get_shape()[3] + + # lots of current limitations: + assert b == 1 + assert ix > 0 + assert iy > 0 + assert ox > 0 + assert oy > 0 + assert fx > 0 + assert fy > 0 + assert c > 0 + assert c % 8 == 0 + assert k % 8 == 0 + assert ox % 8 == 0 - new_input_b = LayoutCast.from_type_and_target_layout( - op.inputs[1], tsl_input_b + tsl_input_a = TiledStridedLayoutAttr( + TiledStridedLayout( + [ + TiledStride([Stride(ix * iy * c, 1)]), # b + TiledStride([Stride(ix * 8, iy)]), # iy + TiledStride([Stride(8, ix)]), # ix + TiledStride([Stride(iy * ix * 8, c // 8), Stride(1, 8)]), # c + ] + ) ) - new_output = LayoutCast.from_type_and_target_layout( - op.outputs[0], tsl_output + tsl_input_b = TiledStridedLayoutAttr( + TiledStridedLayout( + [ + TiledStride([Stride(8 * fy * fx * c, k // 8), Stride(8, 8)]), # k + TiledStride([Stride(64 * fy, fx)]), # fy + TiledStride([Stride(64, fy)]), # fx + TiledStride([Stride(64 * fy * fx, c // 8), Stride(1, 8)]), # c + ] + ) ) - rewriter.insert_op( - (new_input_a, new_input_b, new_output), InsertPoint.before(op) + tsl_output = TiledStridedLayoutAttr( + TiledStridedLayout( + [ + TiledStride([Stride(ox * oy * k, 1)]), # b + TiledStride([Stride(8, oy)]), # oy + TiledStride([Stride(oy * 8, ox)]), # ox + TiledStride([Stride(ox * oy * 8, k // 8), Stride(1, 8)]), # k + ] + ) ) + # insert layout_cast ops + new_input_a = LayoutCast.from_type_and_target_layout(op.inputs[0], tsl_input_a) + + new_input_b = LayoutCast.from_type_and_target_layout(op.inputs[1], tsl_input_b) + + new_output = LayoutCast.from_type_and_target_layout(op.outputs[0], tsl_output) + + rewriter.insert_op((new_input_a, new_input_b, new_output), InsertPoint.before(op)) + if has_add_c: rewriter.insert_op( - new_input_c := LayoutCast.from_type_and_target_layout( - op.inputs[2], tsl_output - ), + new_input_c := LayoutCast.from_type_and_target_layout(op.inputs[2], tsl_output), 
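+                # note: the bias operand C deliberately reuses the output layout (tsl_output)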
InsertPoint.before(op), ) op.operands[shaped_operands[0][0]] = new_input_a.dest @@ -281,17 +380,99 @@ def match_and_rewrite( op.operands[shaped_operands[2][0]] = new_output.dest +@dataclass +class AddCyclicMemoryLayout(RewritePattern): + layout_idx: int = 0 + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: stream.ScheduleOp, rewriter: PatternRewriter): + + # do not alter existing set layouts + for operand in op.operands: + if isinstance(operand.type, MemRefType): + if isinstance(operand.type.layout, TiledStridedLayoutAttr): + return + + # recreate schedule from op + schedule = Schedule( + SchedulePattern(bounds=[x.data for x in bounds.data], pattern=pattern.data) + for pattern, bounds in zip(op.patterns.data, op.bounds.data) + ) + + def generate_one_list(n: int, i: int): + return [1 if j == i else 0 for j in range(n)] + + new_operands = [] + + + for operand, schedule_pattern in zip(op.operands, schedule): + assert isinstance(optype := operand.type, MemRefType) + + strides = [[] for i in range(optype.get_num_dims())] + current_stride = 1 + + for i in reversed(range(schedule_pattern.num_dims)): + result = schedule_pattern.pattern.eval(generate_one_list(schedule_pattern.num_dims, i), []) + result_1 = [1 if x else 0 for x in result] + if 1 in result_1: + dim = result_1.index(1) + existing_bound = prod(s.bound for s in strides[dim]) + dim_shape = -(optype.get_shape()[dim] // -existing_bound) # take ceildiv + # can we apply further tiling? + # FIXME: need better check, now relying on the fact that ix will never + # be divisible and fx will always be odd + if ( + # dim_shape % schedule_pattern.bounds[i] == 0 and + schedule_pattern.bounds[i] % 2 == 0 + and self.layout_idx == 0 + and dim_shape != 1 + ): + bound = schedule_pattern.bounds[i] + else: + bound = dim_shape + strid_obj = Stride(current_stride, bound) + + # maybe some padding for next stride + if bound < 8 and current_stride == 1: + current_stride = current_stride * 8 + # check: maybe possible with transpose? 
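+                    # assumption: sub-8 bounds are padded so strides stay aligned to the
+                    # 8-element tiles used elsewhere in this pass (cf. the Stride(8, 8) patterns)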
+                    elif bound < 8 and current_stride == 8:
+                        current_stride = current_stride * 8
+                    else:
+                        current_stride = current_stride * bound
+
+                    strides[result_1.index(1)].insert(0, strid_obj)
+
+            for stride in strides:
+                if not stride:
+                    stride.append(Stride(current_stride, 1))
+
+            layout = TiledStridedLayout([TiledStride(s) for s in strides])
+            layout = layout.simplify()
+            tsl = TiledStridedLayoutAttr(layout)
+
+            # insert layout_cast ops
+            new_operands.append(LayoutCast.from_type_and_target_layout(operand, tsl))
+
+        rewriter.insert_op(new_operands, InsertPoint.before(op))
+
+        for i, new_operand in enumerate(new_operands):
+            op.operands[i] = new_operand.dest
+
+
 @dataclass(frozen=True)
 class SetMemoryLayout(ModulePass):
     name = "set-memory-layout"
 
     gemm_layout: str = "cyclic"
+    layout_idx: int = 0
 
     def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
-        PatternRewriteWalker(
-            AddMemoryLayoutSIMD(), apply_recursively=False
-        ).rewrite_module(op)
+        PatternRewriteWalker(AddMemoryLayoutSIMD(), apply_recursively=False).rewrite_module(op)
         PatternRewriteWalker(
             AddMemoryLayout(gemm_layout=GemmLayout(self.gemm_layout)),
             apply_recursively=False,
         ).rewrite_module(op)
+        PatternRewriteWalker(AddConvMemoryLayout()).rewrite_module(op)
+        PatternRewriteWalker(AddCyclicMemoryLayout(self.layout_idx)).rewrite_module(op)
diff --git a/compiler/transforms/set_memory_space.py b/compiler/transforms/set_memory_space.py
index 70e91e55..ef3345e2 100644
--- a/compiler/transforms/set_memory_space.py
+++ b/compiler/transforms/set_memory_space.py
@@ -124,7 +124,7 @@ class InitStreamAndLinalgMemorySpace(RewritePattern):
 
     @op_type_rewrite_pattern
     def match_and_rewrite(
-        self, op: linalg.Generic | stream.StreamingRegionOp, rewriter: PatternRewriter
+        self, op: linalg.Generic | stream.StreamingRegionOp | stream.ScheduleOp, rewriter: PatternRewriter
     ):
         operands_to_memory_cast = tuple(
             x
diff --git a/compiler/transforms/stream_bufferize.py b/compiler/transforms/stream_bufferize.py
index d06a483a..a3cdbeab 100644
--- a/compiler/transforms/stream_bufferize.py
+++ b/compiler/transforms/stream_bufferize.py
@@ -2,7 +2,7 @@
 
 from xdsl.context import MLContext
 from xdsl.dialects import bufferization
-from xdsl.dialects.builtin import MemRefType, ModuleOp, TensorType
+from xdsl.dialects.builtin import MemRefType, ModuleOp, TensorType, UnitAttr
 from xdsl.ir import Operation, SSAValue
 from xdsl.passes import ModulePass
 from xdsl.pattern_rewriter import (
@@ -35,11 +35,13 @@ def match_and_rewrite(
 
         for operand in set(operands_to_buffer):
             assert isinstance(tensor_type := operand.type, TensorType)
+            properties = {"read_only": UnitAttr()}
             tensor_to_memrefs[operand] = bufferization.ToMemrefOp(
                 operands=[operand],
                 result_types=(
                     MemRefType(tensor_type.get_element_type(), tensor_type.get_shape()),
                 ),
+                properties=properties,
             )
 
         # create new streaming region op that operates on the buffers
diff --git a/compiler/transforms/test/test_remove_copies.py b/compiler/transforms/test/test_remove_copies.py
new file mode 100644
index 00000000..3ddccda3
--- /dev/null
+++ b/compiler/transforms/test/test_remove_copies.py
@@ -0,0 +1,24 @@
+from xdsl.context import MLContext
+from xdsl.dialects import builtin, memref
+from xdsl.passes import ModulePass
+from xdsl.pattern_rewriter import (
+    PatternRewriter,
+    PatternRewriteWalker,
+    RewritePattern,
+    op_type_rewrite_pattern,
+)
+
+
+class RemoveCopies(RewritePattern):
+    """Erase all memref.CopyOp operations, for testing purposes."""
+
+    @op_type_rewrite_pattern
+    def match_and_rewrite(self, op: memref.CopyOp, rewriter: PatternRewriter):
+        rewriter.erase_matched_op()
+
+
+class RemoveCopiesPass(ModulePass):
+    name = "test-remove-copies"
+
+    def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
+        PatternRewriteWalker(RemoveCopies()).rewrite_module(op)
diff --git a/compiler/util/canonicalize_affine.py b/compiler/util/canonicalize_affine.py
index 767dbc00..eb0b86ad 100644
--- a/compiler/util/canonicalize_affine.py
+++ b/compiler/util/canonicalize_affine.py
@@ -27,7 +27,7 @@ def get_dim(expr: AffineExpr) -> int | None:
     return result
 
 
-def canonicalize_addition(expr: AffineBinaryOpExpr) -> AffineExpr:
+def canonicalize_addition(expr: AffineBinaryOpExpr, extreme=False) -> AffineExpr:
     """
     Canonicalizes an addition by:
     putting the constant on rhs
@@ -72,7 +72,7 @@ def canonicalize_addition(expr: AffineBinaryOpExpr) -> AffineExpr:
     return expr
 
 
-def canonicalize_multiplication(expr: AffineBinaryOpExpr) -> AffineExpr:
+def canonicalize_multiplication(expr: AffineBinaryOpExpr, extreme=False) -> AffineExpr:
     """
     Canonicalizes a multiplication by:
     putting the constant on rhs
@@ -94,7 +94,7 @@ def canonicalize_multiplication(expr: AffineBinaryOpExpr) -> AffineExpr:
     if isinstance(expr.rhs, AffineConstantExpr):
         # rhs is constant, lhs is not
         # multiplication by 1 can be omitted
-        if expr.rhs.value == 1:
+        if expr.rhs.value == 1 or extreme:
             return expr.lhs
     # turn (a + b) * cst into (a * cst) + (b * cst)
     if (
@@ -107,7 +107,7 @@ def canonicalize_multiplication(expr: AffineBinaryOpExpr) -> AffineExpr:
     return expr
 
 
-def canonicalize_floordiv(expr: AffineBinaryOpExpr) -> AffineExpr:
+def canonicalize_floordiv(expr: AffineBinaryOpExpr, extreme=False) -> AffineExpr:
     """
     Canonicalizes a floordiv by:
     omitting a // 1
@@ -120,7 +120,7 @@ def canonicalize_floordiv(expr: AffineBinaryOpExpr) -> AffineExpr:
     return expr
 
 
-def canonicalize_mod(expr: AffineBinaryOpExpr) -> AffineExpr:
+def canonicalize_mod(expr: AffineBinaryOpExpr, extreme=False) -> AffineExpr:
     """
-    Canonicalizes a module operation by:
+    Canonicalizes a modulo operation by:
     replacing a % 1 by constant 0
@@ -133,38 +133,38 @@ def canonicalize_mod(expr: AffineBinaryOpExpr) -> AffineExpr:
     return expr
 
 
-def canonicalize_binary_op(expr: AffineBinaryOpExpr) -> AffineExpr:
+def canonicalize_binary_op(expr: AffineBinaryOpExpr, extreme=False) -> AffineExpr:
     expr = AffineBinaryOpExpr(
-        expr.kind, canonicalize_expr(expr.lhs), canonicalize_expr(expr.rhs)
+        expr.kind, canonicalize_expr(expr.lhs, extreme), canonicalize_expr(expr.rhs, extreme)
     )
     if expr.kind is AffineBinaryOpKind.Add:
-        return canonicalize_addition(expr)
+        return canonicalize_addition(expr, extreme)
     if expr.kind is AffineBinaryOpKind.Mul:
-        return canonicalize_multiplication(expr)
+        return canonicalize_multiplication(expr, extreme)
     if expr.kind is AffineBinaryOpKind.FloorDiv:
-        return canonicalize_floordiv(expr)
+        return canonicalize_floordiv(expr, extreme)
     if expr.kind is AffineBinaryOpKind.Mod:
-        return canonicalize_mod(expr)
+        return canonicalize_mod(expr, extreme)
     return expr
 
 
-def canonicalize_expr(expr: AffineExpr) -> AffineExpr:
+def canonicalize_expr(expr: AffineExpr, extreme=False) -> AffineExpr:
     new_expr = expr
     if isinstance(expr, AffineBinaryOpExpr):
-        new_expr = canonicalize_binary_op(expr)
+        new_expr = canonicalize_binary_op(expr, extreme)
     if new_expr == expr:
         return new_expr
-    return canonicalize_expr(new_expr)
+    return canonicalize_expr(new_expr, extreme)
 
 
 # helper function to canonicalize affine maps
-def canonicalize_map(map: AffineMap) -> AffineMap:
+def canonicalize_map(map: AffineMap, extreme=False) -> AffineMap:
     # canonicalize each result expression of the map and construct new affine map
     return AffineMap(
         map.num_dims,
         map.num_symbols,
-        tuple(canonicalize_expr(expr) for expr in map.results),
+        tuple(canonicalize_expr(expr, extreme) for expr in map.results),
     )
diff --git a/compiler/util/dispatching_rules.py b/compiler/util/dispatching_rules.py
index 7d9bb68f..3778bf42 100644
--- a/compiler/util/dispatching_rules.py
+++ b/compiler/util/dispatching_rules.py
@@ -21,4 +21,6 @@ def dispatch_to_compute(op):
         return True
     if isinstance(op, stream.StreamingRegionOp):
         return True
+    if isinstance(op, stream.ScheduleOp):
+        return True
     return False
diff --git a/compiler/util/multiset.py b/compiler/util/multiset.py
new file mode 100644
index 00000000..dcc910b8
--- /dev/null
+++ b/compiler/util/multiset.py
@@ -0,0 +1,60 @@
+from collections import Counter
+
+
+class Multiset:
+    """
+    In mathematics, a multiset (or bag, or mset) is a modification of the concept of a
+    set that, unlike a set, allows for multiple instances for each of its elements.
+    The number of instances given for each element is called the multiplicity of
+    that element in the multiset.
+    """
+
+    def __init__(self, iterable=None):
+        """Initialize the multiset with an optional iterable"""
+        self.counter = Counter(iterable) if iterable is not None else Counter()
+
+    def add(self, item, count=1):
+        """Add an item to the multiset with a specified count (default is 1)"""
+        self.counter[item] += count
+
+    def remove(self, item, count=1):
+        """Remove count instances of an item (default 1); a no-op if fewer are present"""
+        if self.counter[item] >= count:
+            self.counter[item] -= count
+            if self.counter[item] <= 0:
+                del self.counter[item]
+
+    def count(self, item):
+        """Return the count of an item in the multiset"""
+        return self.counter[item]
+
+    def is_subset(self, other):
+        """Check if this multiset is a subset of another multiset"""
+        for item, count in self.counter.items():
+            if other.counter[item] < count:
+                return False
+        return True
+
+    def union(self, other):
+        """Return a new multiset that is the additive union (sum) of this and another multiset"""
+        return Multiset(self.counter + other.counter)
+
+    def intersection(self, other):
+        """Return a new multiset that is the intersection of this and another multiset"""
+        return Multiset(self.counter & other.counter)
+
+    def difference(self, other):
+        """Return a new multiset that is the difference of this and another multiset"""
+        return Multiset(self.counter - other.counter)
+
+    def __contains__(self, item):
+        """Check if an element is in the multiset"""
+        return item in self.counter
+
+    def __iter__(self):
+        """Allow iteration over the multiset as (element, multiplicity) pairs"""
+        return iter(self.counter.items())
+
+    def __repr__(self):
+        """String representation of the multiset"""
+        return f"Multiset({self.counter})"
diff --git a/kernels/conv/Makefile b/kernels/conv/Makefile
new file mode 100644
index 00000000..3f02ec66
--- /dev/null
+++ b/kernels/conv/Makefile
@@ -0,0 +1,80 @@
+# Courtesy of Federico Ficarelli
+
+.DEFAULT_GOAL := all
+
+include /workspace/runtime/snax-gemmx.rules
+include /workspace/runtime/Makefile.rules
+
+TESTS =
+TESTS += conv.x
+TESTS += generated.x
+
+MLIRPREPROCFLAGS = --linalg-generalize-named-ops
+MLIRPREPROCFLAGS += --mlir-print-op-generic
+MLIRPREPROCFLAGS += --mlir-print-local-scope
+
+ifndef SCHEDULE_IDX
+SCHEDULE_IDX=0
+endif
+
+ifndef PURE_OUTPUT_STATIONARY
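+# assumption: default to a pure output-stationary schedule; this value is
+# forwarded to the autoflow-scheduler pass via SNAXOPTFLAGS below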
+PURE_OUTPUT_STATIONARY=true +endif + +ifndef LAYOUT +LAYOUT=0 +endif + +%.preprocfinal.mlir: %.mlir + $(MLIROPT) $(MLIRPREPROCFLAGS) -o $@ $< + +# correct: +# SNAXOPTFLAGS = --print-op-generic -p 'convert-linalg-to-kernel,insert-accfg-op{accelerator=snax_gemmx},dispatch-kernels,convert-linalg-to-stream,fuse-streaming-regions,stream-bufferize,snax-bufferize,alloc-to-global,autoflow-scheduler{schedule_idx=${SCHEDULE_IDX} pure_output_stationary=${PURE_OUTPUT_STATIONARY}},set-memory-space,set-memory-layout,realize-memref-casts,insert-sync-barrier,dispatch-regions{nb_cores=3},autoflow-layout-resolution,convert-linalg-to-accfg,test-add-mcycle-around-launch,convert-accfg-to-csr,snax-copy-to-dma,memref-to-snax,snax-to-func,snax-lower-mcycle,clear-memory-space' + +# no copies: +SNAXOPTFLAGS = --print-op-generic -p 'convert-linalg-to-kernel,insert-accfg-op{accelerator=snax_gemmx},dispatch-kernels,convert-linalg-to-stream,fuse-streaming-regions,stream-bufferize,snax-bufferize,alloc-to-global,autoflow-scheduler{schedule_idx=${SCHEDULE_IDX} pure_output_stationary=${PURE_OUTPUT_STATIONARY}},set-memory-space,set-memory-layout{layout_idx=${LAYOUT}},realize-memref-casts,insert-sync-barrier,dispatch-regions{nb_cores=3},autoflow-layout-resolution,convert-linalg-to-accfg,test-add-mcycle-around-launch,convert-accfg-to-csr,test-remove-copies,memref-to-snax,snax-to-func,snax-lower-mcycle,clear-memory-space' + +CFLAGS += -std=gnu11 +CFLAGS += -Wall -Wextra + +%.x: %.o main.o + $(LD) $(LDFLAGS) $^ -o $@ + +sim_%: % + rm -fr ./logs/ + $(VLTSIM) $< + +RUN = $(addprefix run_, $(TESTS)) +$(RUN): run_%: sim_% + mv logs $(subst sim_,,$<).logs + +all: $(TESTS) + +allrun: $(RUN) + +clean: + rm -fr *.ll12 *.x *.o *.logs/ logs/ data.h data.c *.mlir + +# List of directories +CONV_DIRS := $(wildcard results/*/) + +# Variable to count completed jobs +COMPLETED := 0 + +# Default target +convs: $(CONV_DIRS) + @echo "All jobs completed! Processed $(COMPLETED)/$(words $(CONV_DIRS)) directories." + +# Target for each directory +$(CONV_DIRS): + @echo "Processing $@... 
($(COMPLETED)/$(words $(CONV_DIRS)))" + @$(MAKE) -C $@ sim_generated.x > $@/sim_generated.log 2>&1 + @echo "sim_generated.x completed for $@" >> $@/sim_generated.log + @$(MAKE) -C $@ traces >> $@/traces.log 2>&1 + @echo "traces completed for $@" >> $@/traces.log + @$(MAKE) -C $@ clean >> $@/cleaning.log 2>&1 + @echo "cleaning completed for $@" >> $@/cleaning.log + $(eval COMPLETED := $(shell echo $$(( $(COMPLETED) + 1 )))) + +.PHONY: all $(CONV_DIRS) diff --git a/kernels/conv/analyze.py b/kernels/conv/analyze.py new file mode 100644 index 00000000..87c23d44 --- /dev/null +++ b/kernels/conv/analyze.py @@ -0,0 +1,47 @@ +import os +import json +import pandas as pd + +# Define the base directory for test folders +base_dir = "results" +conv_dirs = [os.path.join(base_dir, d) for d in os.listdir(base_dir)] + +data = [] + +# Process each directory +for dir_path in conv_dirs: + # Initialize entry + entry = {"Directory": dir_path, "Errors": None, "Cycles": None} + + # Check sim_generated.log for errors + sim_log_file = os.path.join(dir_path, "sim_generated.log") + if os.path.isfile(sim_log_file): + try: + with open(sim_log_file, 'r') as log_file: + for line in log_file: + if "Finished, nb errors:" in line: + entry["Errors"] = int(line.split(":")[-1].strip()) + break + except Exception as e: + entry["Errors"] = "Error reading log" + + # Check trace file for cycles + trace_file = os.path.join(dir_path, "logs/trace_chip_00_hart_00000.trace.json") + if os.path.isfile(trace_file): + try: + with open(trace_file, 'r') as trace_file_content: + trace_data = json.load(trace_file_content) + if len(trace_data) > 2 and 'cycles' in trace_data[2]: + entry["Cycles"] = trace_data[2]['cycles'] + except Exception as e: + entry["Cycles"] = "Error reading trace" + + data.append(entry) + +# Convert to DataFrame for better readability and export +df = pd.DataFrame(data) + +# Save or display the results +df.to_csv("test_results_summary.csv", index=False) +print(df) + diff --git a/kernels/conv/genbenchmark.py b/kernels/conv/genbenchmark.py new file mode 100644 index 00000000..1a2ae4a1 --- /dev/null +++ b/kernels/conv/genbenchmark.py @@ -0,0 +1,211 @@ +import itertools +import pathlib +import shutil + +from genkernel import ConvSpec, generate_conv_ir +from xdsl.ir import StringIO +from xdsl.printer import Printer + +from util.snax_benchmark import SNAXBenchmark + + +def write_module_to_file(module, file): + output = StringIO() + printer = Printer(stream=output) + printer.print(module) + with open(file, "w") as output_file: + output_file.write(output.getvalue()) + + +def write_makefile(file, schedule_idx=None): + with open(file, "w") as output_file: + output_file.write("include $(realpath ../../Makefile)\n") + if schedule_idx is not None: + output_file.write(f"SCHEDULE_IDX={schedule_idx}\n") + output_file.write("clean:\n") + output_file.write("\trm -rf logs/*.txt logs/*.dasm\n") + + +def generate_layout_benchmark_gemm(): + + selected_dims = [128, 136, 144, 152] + selected_layouts = [0, 1] + + specs = [ + ConvSpec(1, m, 1, 1, 1, n, k) + for m, n, k in itertools.product(selected_dims, repeat=3) + ] + + for layout in selected_layouts: + for i, spec in enumerate(specs): + module = generate_conv_ir(spec, generate_constants=False) + + binary = "generated.x" + bm = SNAXBenchmark( + kernel=f"layout_{layout}_{i:03d}", + binary=binary, + src_dir=str(pathlib.Path.cwd()), + export_dir=str(pathlib.Path.cwd()), + output_dir=str(pathlib.Path.cwd()), + ) + + bm.clean() + write_module_to_file(module, "generated.mlir") + 
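+            # PURE_OUTPUT_STATIONARY and LAYOUT are consumed by the conv kernel Makefile (see SNAXOPTFLAGS above)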
bm.build(["PURE_OUTPUT_STATIONARY=true", f"LAYOUT={layout}"]) + # bm.run() + # bm.trace() + bm.copy_binary("") + write_makefile(bm.export_dir / "Makefile") + # bm.copy_logs('') + bm.clean() + + +def generate_layout_benchmark_conv(): + + selected_dims = [32, 40, 48, 56, 64, 72, 80] + selected_layouts = [0, 1] + + specs = [ + ConvSpec(1, 8, 16, 5, 5, c, k) + for c, k in itertools.product(selected_dims, repeat=2) + ] + + for layout in selected_layouts: + for i, spec in enumerate(specs): + module = generate_conv_ir(spec, generate_constants=False) + + binary = "generated.x" + bm = SNAXBenchmark( + kernel=f"layout_{layout}_{i:03d}", + binary=binary, + src_dir=str(pathlib.Path.cwd()), + export_dir=str(pathlib.Path.cwd()), + output_dir=str(pathlib.Path.cwd()), + ) + + bm.clean() + write_module_to_file(module, "generated.mlir") + bm.build(["PURE_OUTPUT_STATIONARY=true", f"LAYOUT={layout}"]) + # bm.run() + # bm.trace() + bm.copy_binary("") + write_makefile(bm.export_dir / "Makefile") + # bm.copy_logs('') + bm.clean() + + +def generate_temp_mapping_benchmark(): + # conv 3x3 + # spec = ConvSpec(1, 32, 16, 3, 3, 64, 64) + + # conv 7x7 + # spec = ConvSpec(1, 16, 16, 7, 7, 64, 16) + + # pw conv + # spec = ConvSpec(1, 14, 14, 1, 1, 1024, 256) + + # gemm + # spec = ConvSpec(1, 256, 1, 1, 1, 256, 256) + + # strided conv + # spec = ConvSpec(1, 16, 16, 4, 4, 64, 64) + + # dilated conv + spec = ConvSpec(1, 16, 16, 5, 5, 64, 64, dilation=2) + + module = generate_conv_ir(spec, generate_constants=False) + + for schedule_idx in range(720): + binary = "generated.x" + bm = SNAXBenchmark( + kernel=f"conv_{schedule_idx:04d}", + binary=binary, + src_dir=str(pathlib.Path.cwd()), + export_dir=str(pathlib.Path.cwd()), + output_dir=str(pathlib.Path.cwd()), + ) + + # bm.cdst_folder = pathlib.Path(self.export_dir / folder) + if not bm.export_dir.exists(): + bm.export_dir.mkdir(parents=True) + # write_module_to_file(module, bm.export_dir / "generated.mlir") + write_module_to_file(module, bm.export_dir / "generated.mlir") + # bm.build([f"SCHEDULE_IDX={schedule_idx}", "PURE_OUTPUT_STATIONARY=false"]) + # bm.run() + # bm.trace() + shutil.copy(src=bm.src_dir / "main.c", dst=bm.export_dir / "main.c") + write_makefile(bm.export_dir / "Makefile", schedule_idx) + # bm.copy_logs('') + # bm.clean() + + +def generate_resnet_benchmark(): + specs = [ + # input layer: + ConvSpec(1, 112, 112, 7, 7, 3, 64, stride=2), + # block 1: + # downsampling layer: + ConvSpec(1, 56, 56, 1, 1, 64, 64), + # bottleneck: + ConvSpec(1, 56, 56, 1, 1, 256, 64), + ConvSpec(1, 56, 56, 3, 3, 64, 64), + ConvSpec(1, 56, 56, 1, 1, 64, 256), + # block 2: + # downsampling layer: + ConvSpec(1, 28, 28, 1, 1, 256, 512, stride=2), + # bottleneck: + ConvSpec(1, 28, 28, 1, 1, 256, 128), + ConvSpec(1, 28, 28, 1, 1, 512, 128), + ConvSpec(1, 28, 28, 3, 3, 128, 128), + ConvSpec(1, 28, 28, 1, 1, 128, 512), + # block 3: + # downsampling layer: + ConvSpec(1, 14, 14, 1, 1, 512, 1024, stride=2), + # bottleneck: + ConvSpec(1, 14, 14, 1, 1, 512, 256), + ConvSpec(1, 14, 14, 1, 1, 1024, 256), + ConvSpec(1, 14, 14, 3, 3, 256, 256), + ConvSpec(1, 14, 14, 1, 1, 256, 1024), + # block 4: + # downsampling layer: + ConvSpec(1, 7, 7, 1, 1, 1024, 2048, stride=2), + # bottleneck: + ConvSpec(1, 7, 7, 1, 1, 1024, 512), + ConvSpec(1, 7, 7, 1, 1, 2048, 512), + ConvSpec(1, 7, 7, 3, 3, 512, 512), + ConvSpec(1, 7, 7, 1, 1, 512, 2048), + ] + + for i, spec in enumerate(specs): + + module = generate_conv_ir(spec, generate_constants=False) + + binary = "generated.x" + bm = SNAXBenchmark( + 
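+            # one benchmark instance per ResNet-50 layer in the specs list above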
kernel=f"resnet50_{i:03d}",
+            binary=binary,
+            src_dir=str(pathlib.Path.cwd()),
+            export_dir=str(pathlib.Path.cwd()),
+            output_dir=str(pathlib.Path.cwd()),
+        )
+
+        if not bm.export_dir.exists():
+            bm.export_dir.mkdir(parents=True)
+        # write_module_to_file(module, bm.export_dir / "generated.mlir")
+        write_module_to_file(module, bm.export_dir / "generated.mlir")
+        # bm.build([f"SCHEDULE_IDX={schedule_idx}", "PURE_OUTPUT_STATIONARY=false"])
+        # bm.run()
+        # bm.trace()
+        shutil.copy(src=bm.src_dir / "main.c", dst=bm.export_dir / "main.c")
+        write_makefile(bm.export_dir / "Makefile")
+        # bm.copy_logs('')
+        bm.clean()
+
+
+if __name__ == "__main__":
+    # generate_layout_benchmark_gemm()
+    # generate_layout_benchmark_conv()
+    generate_temp_mapping_benchmark()
+    # generate_resnet_benchmark()
diff --git a/kernels/conv/genkernel.py b/kernels/conv/genkernel.py
new file mode 100644
index 00000000..addb7d19
--- /dev/null
+++ b/kernels/conv/genkernel.py
@@ -0,0 +1,89 @@
+from dataclasses import dataclass
+
+import numpy as np
+import torch.nn as nn
+from torch import Tensor
+from xdsl.builder import Builder
+from xdsl.dialects import arith, builtin, func, linalg, tensor
+from xdsl.dialects.builtin import i8, i32, i64
+from xdsl.ir import StringIO
+from xdsl.parser import DenseIntOrFPElementsAttr
+from xdsl.printer import Printer
+
+
+@dataclass(frozen=True)
+class ConvSpec:
+    b: int
+    ox: int
+    oy: int
+    fx: int
+    fy: int
+    c: int
+    k: int
+    groups: int = 1
+    stride: int = 1
+    dilation: int = 1
+
+
+def generate_conv_ir(spec: ConvSpec, generate_constants: bool = True):
+    # value bounds for the randomly generated test data
+    low_bound = 0
+    high_bound = 10
+
+    I_size = [spec.b, spec.c, (spec.oy * spec.stride) + spec.dilation * (spec.fy - 1), (spec.ox * spec.stride) + spec.dilation * (spec.fx - 1)]
+    W_size = [spec.k, spec.c, spec.fy, spec.fx]
+
+    if generate_constants:
+        I = np.random.randint(low_bound, high_bound, size=I_size, dtype=np.dtype("int8"))
+        input_vals: list[int] = I.flatten().tolist()
+        W = np.random.randint(low_bound, high_bound, size=W_size, dtype=np.dtype("int8"))
+        weight_vals: list[int] = W.flatten().tolist()
+
+        conv = nn.Conv2d(spec.c, spec.k, (spec.fy, spec.fx), stride=spec.stride, dilation=spec.dilation, groups=spec.groups, bias=False)
+        conv.weight = nn.Parameter(Tensor(W))
+
+        O = np.round(conv(Tensor(I)).detach().numpy()).astype(np.int32)
+        output_vals: list[int] = O.flatten().tolist()
+
+    else:
+        input_vals = [0]
+        output_vals = [0]
+        weight_vals = [0]
+
+    input_type = builtin.TensorType(i8, tuple(I_size))
+    kernel_type = builtin.TensorType(i8, tuple(W_size))
+    output_type = builtin.TensorType(i32, (spec.b, spec.k, spec.oy, spec.ox))
+
+    # function returns golden output + computed output
+    res_types = [output_type, output_type]
+
+    dilations = DenseIntOrFPElementsAttr.tensor_from_list([spec.dilation], i64, [2])
+    strides = DenseIntOrFPElementsAttr.tensor_from_list([spec.stride], i64, [2])
+
+    @Builder.implicit_region([])
+    def func_body(_) -> None:
+        inputs = arith.Constant(DenseIntOrFPElementsAttr.from_list(input_type, input_vals), input_type)
+        weights = arith.Constant(DenseIntOrFPElementsAttr.from_list(kernel_type, weight_vals), kernel_type)
+        golden = arith.Constant(DenseIntOrFPElementsAttr.from_list(output_type, output_vals), output_type)
+        empty_tensor = tensor.EmptyOp([], output_type)
+        result = linalg.Conv2DNchwFchwOp(
+            dilations, strides, (inputs.result, weights.result), (empty_tensor.results[0],)
+        )
+        func.Return(result, golden)
+
+    function = func.FuncOp.from_region("conv", [], res_types, func_body)
+
+    module = builtin.ModuleOp([function])
+
+    return module
+
+
+if __name__ == "__main__":
+    spec = ConvSpec(1, 16, 16, 3, 3, 16, 16)
+
+    output = StringIO()
+    printer = Printer(stream=output)
+    printer.print(generate_conv_ir(spec))
+    with open("conv.mlir", "w") as output_file:
+        output_file.write(output.getvalue())
diff --git a/kernels/conv/main.c b/kernels/conv/main.c
new file mode 100644
index 00000000..4e41b573
--- /dev/null
+++ b/kernels/conv/main.c
@@ -0,0 +1,79 @@
+#include "stdint.h"
+
+#include "memref.h"
+#include "snax_rt.h"
+
+/*
+ * These libraries are included from github.com/KULeuven-MICAS/snitch_cluster
+ * Interested users might want to look at:
+ *
+ * /sw/snRuntime/api
+ * /target/snitch_cluster/sw/runtime/rtl/src
+ * /target/snitch_cluster/sw/runtime/common
+ * */
+#include <snrt.h>
+
+// Kernel provided via external definition
+void _mlir_ciface_conv(FourDMemrefI32_t *results);
+
+int main() {
+  {
+
+    FourDMemrefI32_t results[2];
+
+    FourDMemrefI32_t *golden, *computed;
+    golden = &results[0];
+    computed = &results[1];
+
+    // allocate zero row in tcdm
+    snrt_l1alloc(256);
+
+    (void)snrt_mcycle();
+    snrt_cluster_hw_barrier();
+
+    _mlir_ciface_conv(results);
+
+    snrt_cluster_hw_barrier();
+    (void)snrt_mcycle();
+
+    // Correctness check -
+    // from this point on only core 0 is required to be alive.
+    int thiscore = snrt_cluster_core_idx();
+    if (thiscore != 0)
+      return 0;
+
+    printf("Got golden result:\n");
+    printf("Pointer at %p\n", golden->data);
+    printf("Aligned Pointer at %p\n", golden->aligned_data);
+
+    printf("Got computed result:\n");
+    printf("Pointer at %p\n", computed->data);
+    printf("Aligned Pointer at %p\n", computed->aligned_data);
+
+    int total_results = 1;
+    for (int i = 0; i < 4; i++)
+      total_results *= computed->shape[i];
+
+    printf("Skipping correctness check of %d results (checking disabled)...\n", total_results);
+    return 0;
+
+    printf("Checking %d results...\n", total_results);
+
+    int nerr = 0;
+
+    for (int i = 0; i < total_results; i++) {
+
+      if (golden->aligned_data[i] != computed->aligned_data[i]) {
+        printf("(%d) %d -> %d\n", i, golden->aligned_data[i], computed->aligned_data[i]);
+        nerr++;
+      }
+    }
+
+    printf("Finished, nb errors: %d\n", nerr);
+
+    if (nerr > 0)
+      return 1;
+    else
+      return 0;
+  }
+}
diff --git a/runtime/include/memref.h b/runtime/include/memref.h
index 43eb75f4..66fbce59 100644
--- a/runtime/include/memref.h
+++ b/runtime/include/memref.h
@@ -42,7 +42,29 @@ struct TwoDMemrefI8 {
   uint32_t stride[2];
 };
 
+struct FourDMemrefI32 {
+  int32_t *data; // allocated pointer: Pointer to data buffer as allocated,
+                 // only used for deallocating the memref
+  int32_t *aligned_data; // aligned pointer: Pointer to properly aligned data
+                         // that memref indexes
+  uint32_t offset;
+  uint32_t shape[4];
+  uint32_t stride[4];
+};
+
+struct FourDMemrefI8 {
+  int8_t *data; // allocated pointer: Pointer to data buffer as allocated,
+                // only used for deallocating the memref
+  int8_t *aligned_data; // aligned pointer: Pointer to properly aligned data
+                        // that memref indexes
+  uint32_t offset;
+  uint32_t shape[4];
+  uint32_t stride[4];
+};
+
 typedef struct OneDMemrefI32 OneDMemrefI32_t;
 typedef struct OneDMemrefI64 OneDMemrefI64_t;
 typedef struct TwoDMemrefI8 TwoDMemrefI8_t;
 typedef struct TwoDMemrefI32 TwoDMemrefI32_t;
+typedef struct FourDMemrefI32 FourDMemrefI32_t;
+typedef struct FourDMemrefI8 FourDMemrefI8_t;
diff --git a/tests/filecheck/transforms/convert-stream-to-snax-stream.mlir b/tests/filecheck/transforms/convert-stream-to-snax-stream.mlir
index b82d87e8..d2e75aa5 100644
--- a/tests/filecheck/transforms/convert-stream-to-snax-stream.mlir
+++ b/tests/filecheck/transforms/convert-stream-to-snax-stream.mlir
@@ -42,7 +42,7 @@ func.func @streamer_matmul(%arg0 : memref<16x16xi8>, %arg1 : memref<16x16xi8, st
   func.return
 }
 
-// CHECK: "snax_stream.streaming_region"(%1, %2, %3, %4, %5) <{"stride_patterns" = [#snax_stream.stride_pattern, #snax_stream.stride_pattern, #snax_stream.stride_pattern, #snax_stream.stride_pattern, #snax_stream.stride_pattern], "accelerator" = "snax_gemmx", "operandSegmentSizes" = array}> ({
+// CHECK: "snax_stream.streaming_region"(%1, %2, %3, %4, %5) <{"stride_patterns" = [#snax_stream.stride_pattern, #snax_stream.stride_pattern, #snax_stream.stride_pattern, #snax_stream.stride_pattern, #snax_stream.stride_pattern], "accelerator" = "snax_gemmx", "operandSegmentSizes" = array}> ({
 // CHECK-NEXT: ^0(%6 : !stream.stream, %7 : !stream.stream, %8 : !stream.stream):
 // CHECK-NEXT: %9 = "stream.generic"(%6, %7, %0, %0) <{"library_call" = "snax_gemmx"}> ({
 // CHECK-NEXT: ^1(%arg3 : i8, %arg4 : i8, %arg5 : i32, %arg6 : i32, %arg7 : i32):
diff --git a/tests/ir/stream/test_access_pattern.py b/tests/ir/stream/test_access_pattern.py
index 8fd84217..7f712bb9 100644
--- a/tests/ir/stream/test_access_pattern.py
+++ b/tests/ir/stream/test_access_pattern.py
@@ -221,7 +221,9 @@ def test_template_pattern_matches():
     sp_non_matching_pattern = SchedulePattern(
         bounds,
         AffineMap(
-            num_dims=2, num_symbols=0, results=(AffineDimExpr(1), AffineDimExpr(0))
+            num_dims=2,
+            num_symbols=0,
+            results=(AffineDimExpr(1) + AffineDimExpr(0), AffineDimExpr(0)),
         ),
     )
     sp_non_matching_bounds = SchedulePattern((5, 15), pattern)
@@ -318,7 +320,9 @@ def test_template_matches_schedule():
     sp3 = SchedulePattern(
         (10, 20),
         AffineMap(
-            num_dims=2, num_symbols=0, results=(AffineDimExpr(1), AffineDimExpr(0))
+            num_dims=2,
+            num_symbols=0,
+            results=(AffineDimExpr(1) + AffineDimExpr(0), AffineDimExpr(0)),
         ),
     )
     schedule_non_matching = Schedule([sp1, sp3])
diff --git a/tests/ir/stream/test_scheduler.py b/tests/ir/stream/test_scheduler.py
index d17887df..2f4ca0e0 100644
--- a/tests/ir/stream/test_scheduler.py
+++ b/tests/ir/stream/test_scheduler.py
@@ -18,7 +18,7 @@ def test_matching_1o():
 
     resulting_schedule = scheduler(template, schedule)
 
-    assert schedule == resulting_schedule.clear_unused_dims()
+    assert template.matches(resulting_schedule.clear_unused_dims())
 
 
 def test_matching_2o():
@@ -30,7 +30,7 @@ def test_matching_2o():
 
     resulting_schedule = scheduler(template, schedule)
 
-    assert schedule == resulting_schedule.clear_unused_dims()
+    assert template.matches(resulting_schedule.clear_unused_dims())
 
 
 def test_tiling_1o1_1d():
diff --git a/tests/util/test_multiset.py b/tests/util/test_multiset.py
new file mode 100644
index 00000000..6b7814ec
--- /dev/null
+++ b/tests/util/test_multiset.py
@@ -0,0 +1,68 @@
+from collections import Counter
+
+from compiler.util.multiset import Multiset
+
+
+def test_multiset_initialization():
+    multiset = Multiset([1, 2, 2, 3])
+    assert multiset.counter == Counter([1, 2, 2, 3])
+
+
+def test_multiset_add():
+    multiset = Multiset([1, 2])
+    multiset.add(2)
+    multiset.add(3, 2)
+    assert multiset.counter == Counter([1, 2, 2, 3, 3])
+
+
+def test_multiset_remove():
+    multiset = Multiset([1, 2, 2, 3])
+    multiset.remove(2)
+    multiset.remove(3)
+    assert multiset.counter == Counter([1, 2])
+
+
+def test_multiset_remove_item_not_present():
+    multiset = Multiset([1, 2])
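+    # (by design, remove() is also a no-op when fewer than count instances are present)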
multiset.remove(3) # Removing an item that isn't there shouldn't throw an error + assert multiset.counter == Counter([1, 2]) + + +def test_multiset_count(): + multiset = Multiset([1, 2, 2, 3]) + assert multiset.count(2) == 2 + assert multiset.count(3) == 1 + + +def test_multiset_is_subset(): + multiset1 = Multiset([1, 2, 2, 3]) + multiset2 = Multiset([1, 2, 2, 3, 3, 4]) + assert multiset1.is_subset(multiset2) + + +def test_multiset_is_not_subset(): + multiset1 = Multiset([1, 2, 2, 3, 5]) + multiset2 = Multiset([1, 2, 2, 3, 3, 4]) + assert not multiset1.is_subset(multiset2) + + +def test_multiset_union(): + multiset1 = Multiset([1, 2, 2]) + multiset2 = Multiset([2, 3]) + union_set = multiset1.union(multiset2) + assert union_set.counter == Counter([1, 2, 2, 2, 3]) + + +def test_multiset_intersection(): + multiset1 = Multiset([1, 2, 2, 3]) + multiset2 = Multiset([2, 2, 3, 4]) + intersection_set = multiset1.intersection(multiset2) + assert intersection_set.counter == Counter([2, 2, 3]) + + +def test_multiset_difference(): + multiset1 = Multiset([1, 2, 2, 3]) + multiset2 = Multiset([2, 3]) + difference_set = multiset1.difference(multiset2) + assert difference_set.counter == Counter([1, 2])
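A note on the AddCyclicMemoryLayout pattern earlier in this patch: its stride-assignment loop is easier to follow outside of xdsl. The snippet below is a minimal, illustrative pure-Python sketch, not the pass itself. It assumes plain (stride, bound) tuples in place of the Stride/TiledStride attributes, a hypothetical dim_of_schedule_dim mapping in place of evaluating the affine access pattern, and it omits the further-tiling branch, so every bound is simply the remaining (ceil-divided) extent of its memref dimension.

from math import prod


def assign_strides(dim_of_schedule_dim: list[int], shape: list[int]) -> list[list[tuple[int, int]]]:
    """For each memref dim, build (stride, bound) pairs, walking the schedule innermost-first."""
    strides: list[list[tuple[int, int]]] = [[] for _ in shape]
    current_stride = 1
    for i in reversed(range(len(dim_of_schedule_dim))):
        dim = dim_of_schedule_dim[i]
        # remaining extent of this memref dim after already-assigned tiles (ceildiv)
        bound = -(shape[dim] // -prod(b for _, b in strides[dim]))
        strides[dim].insert(0, (current_stride, bound))
        if bound < 8 and current_stride == 1:
            # padding: a sub-8 innermost bound still advances the next stride to 8
            current_stride = 8
        else:
            current_stride *= bound
    # memref dims never visited by the schedule get a unit-bound stride
    for per_dim in strides:
        if not per_dim:
            per_dim.append((current_stride, 1))
    return strides


# e.g. a two-dim schedule over a 16x24 memref: schedule dim 0 walks rows, dim 1 columns
print(assign_strides([0, 1], shape=[16, 24]))
# -> [[(24, 16)], [(1, 24)]]: unit stride along columns, rows strided by the row length

As in the pass, the padding branch keeps the running stride 8-aligned whenever an innermost bound is smaller than 8, which is what keeps the resulting layouts compatible with the 8-element tiles used by the GEMM layouts earlier in the file.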