From 5daf5f681f961abb81465ea9d4cac116aa141d90 Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Sun, 7 Jul 2024 15:07:30 +0100 Subject: [PATCH 1/2] dialects: (snitch_stream) add custom syntax to streaming region --- .../filecheck/dialects/snitch_stream/ops.mlir | 32 ++-- .../convert_memref_stream_to_snitch.mlir | 148 +++++++++++------- .../snitch_register_allocation.mlir | 10 +- xdsl/dialects/snitch_stream.py | 107 ++++++++++++- 4 files changed, 220 insertions(+), 77 deletions(-) diff --git a/tests/filecheck/dialects/snitch_stream/ops.mlir b/tests/filecheck/dialects/snitch_stream/ops.mlir index 783104f511..a18eebf6f8 100644 --- a/tests/filecheck/dialects/snitch_stream/ops.mlir +++ b/tests/filecheck/dialects/snitch_stream/ops.mlir @@ -3,10 +3,11 @@ %X, %Y, %Z = "test.op"() : () -> (!riscv.reg, !riscv.reg, !riscv.reg) -"snitch_stream.streaming_region"(%X, %Y, %Z) <{ - "stride_patterns" = [#snitch_stream.stride_pattern], - "operandSegmentSizes" = array -}> ({ +snitch_stream.streaming_region { + stride_patterns = [ + #snitch_stream.stride_pattern + ] +} ins(%X, %Y : !riscv.reg, !riscv.reg) outs(%Z : !riscv.reg) { ^0(%a_stream : !stream.readable>, %b_stream : !stream.readable>, %c_stream : !stream.writable>): %c5 = riscv.li 5 : !riscv.reg riscv_snitch.frep_outer %c5 { @@ -15,21 +16,24 @@ %c = riscv.fadd.d %a, %b : (!riscv.freg, !riscv.freg) -> !riscv.freg riscv_snitch.write %c to %c_stream : !riscv.freg } -}) : (!riscv.reg, !riscv.reg, !riscv.reg) -> () +} // CHECK: %X, %Y, %Z = "test.op"() : () -> (!riscv.reg, !riscv.reg, !riscv.reg) -// CHECK-NEXT: "snitch_stream.streaming_region"(%X, %Y, %Z) <{"stride_patterns" = [#snitch_stream.stride_pattern], "operandSegmentSizes" = array}> ({ +// CHECK-NEXT: snitch_stream.streaming_region { +// CHECK-NEXT: patterns = [ +// CHECK-NEXT: #snitch_stream.stride_pattern +// CHECK-NEXT: ] +// CHECK-NEXT: } ins(%X, %Y : !riscv.reg, !riscv.reg) outs(%Z : !riscv.reg) { // CHECK-NEXT: ^0(%a_stream : !stream.readable>, %b_stream : !stream.readable>, %c_stream : !stream.writable>): -// CHECK-NEXT: %c5 = riscv.li 5 : !riscv.reg -// CHECK-NEXT: riscv_snitch.frep_outer %c5 { -// CHECK-NEXT: %a = riscv_snitch.read from %a_stream : !riscv.freg -// CHECK-NEXT: %b = riscv_snitch.read from %b_stream : !riscv.freg -// CHECK-NEXT: %c = riscv.fadd.d %a, %b : (!riscv.freg, !riscv.freg) -> !riscv.freg -// CHECK-NEXT: riscv_snitch.write %c to %c_stream : !riscv.freg +// CHECK-NEXT: %c5 = riscv.li 5 : !riscv.reg +// CHECK-NEXT: riscv_snitch.frep_outer %c5 { +// CHECK-NEXT: %a = riscv_snitch.read from %a_stream : !riscv.freg +// CHECK-NEXT: %b = riscv_snitch.read from %b_stream : !riscv.freg +// CHECK-NEXT: %c = riscv.fadd.d %a, %b : (!riscv.freg, !riscv.freg) -> !riscv.freg +// CHECK-NEXT: riscv_snitch.write %c to %c_stream : !riscv.freg +// CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: }) : (!riscv.reg, !riscv.reg, !riscv.reg) -> () - // CHECK-GENERIC: %X, %Y, %Z = "test.op"() : () -> (!riscv.reg, !riscv.reg, !riscv.reg) // CHECK-GENERIC-NEXT: "snitch_stream.streaming_region"(%X, %Y, %Z) <{"stride_patterns" = [#snitch_stream.stride_pattern], "operandSegmentSizes" = array}> ({ diff --git a/tests/filecheck/transforms/convert_memref_stream_to_snitch.mlir b/tests/filecheck/transforms/convert_memref_stream_to_snitch.mlir index 8d15fb25aa..cb22352767 100644 --- a/tests/filecheck/transforms/convert_memref_stream_to_snitch.mlir +++ b/tests/filecheck/transforms/convert_memref_stream_to_snitch.mlir @@ -30,16 +30,22 @@ memref_stream.streaming_region { } // CHECK-NEXT: %A, %B, %C = "test.op"() : () -> (memref<2xf64>, memref<3xf64>, memref<3x2xf64>) -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %A : memref<2xf64> to !riscv.reg -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %B : memref<3xf64> to !riscv.reg -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %C : memref<3x2xf64> to !riscv.reg -// CHECK-NEXT: "snitch_stream.streaming_region"(%{{.*}}, %{{.*}}, %{{.*}}) <{"stride_patterns" = [#snitch_stream.stride_pattern, #snitch_stream.stride_pattern, #snitch_stream.stride_pattern], "operandSegmentSizes" = array}> ({ -// CHECK-NEXT: ^0(%a : !stream.readable, %b : !stream.readable, %c : !stream.writable): -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %a : !stream.readable to !stream.readable -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %b : !stream.readable to !stream.readable -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %c : !stream.writable to !stream.writable -// CHECK-NEXT: "test.op"(%{{.*}}, %{{.*}}, %{{.*}}) : (!stream.readable, !stream.readable, !stream.writable) -> () -// CHECK-NEXT: }) : (!riscv.reg, !riscv.reg, !riscv.reg) -> () +// CHECK-NEXT: %A_1 = builtin.unrealized_conversion_cast %A : memref<2xf64> to !riscv.reg +// CHECK-NEXT: %B_1 = builtin.unrealized_conversion_cast %B : memref<3xf64> to !riscv.reg +// CHECK-NEXT: %C_1 = builtin.unrealized_conversion_cast %C : memref<3x2xf64> to !riscv.reg +// CHECK-NEXT: snitch_stream.streaming_region { +// CHECK-NEXT: patterns = [ +// CHECK-NEXT: #snitch_stream.stride_pattern, +// CHECK-NEXT: #snitch_stream.stride_pattern, +// CHECK-NEXT: #snitch_stream.stride_pattern +// CHECK-NEXT: ] +// CHECK-NEXT: } ins(%A_1, %B_1 : !riscv.reg, !riscv.reg) outs(%C_1 : !riscv.reg) { +// CHECK-NEXT: ^{{.*}}(%a : !stream.readable, %b : !stream.readable, %c : !stream.writable): +// CHECK-NEXT: %a_1 = builtin.unrealized_conversion_cast %a : !stream.readable to !stream.readable +// CHECK-NEXT: %b_1 = builtin.unrealized_conversion_cast %b : !stream.readable to !stream.readable +// CHECK-NEXT: %c_1 = builtin.unrealized_conversion_cast %c : !stream.writable to !stream.writable +// CHECK-NEXT: "test.op"(%a_1, %b_1, %c_1) : (!stream.readable, !stream.readable, !stream.writable) -> () +// CHECK-NEXT: } memref_stream.streaming_region { patterns = [ @@ -51,15 +57,18 @@ memref_stream.streaming_region { "test.op"(%c0, %c1) : (!stream.readable, !stream.readable) -> () } -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %C : memref<3x2xf64> to !riscv.reg -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %C : memref<3x2xf64> to !riscv.reg -// CHECK-NEXT: "snitch_stream.streaming_region"(%{{.*}}, %{{.*}}) <{"stride_patterns" = [#snitch_stream.stride_pattern], "operandSegmentSizes" = array}> ({ -// CHECK-NEXT: ^{{.*}}(%c0 : !stream.readable, %c1 : !stream.readable): -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %c0 : !stream.readable to !stream.readable -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %c1 : !stream.readable to !stream.readable -// CHECK-NEXT: "test.op"(%{{.*}}, %{{.*}}) : (!stream.readable, !stream.readable) -> () -// CHECK-NEXT: }) : (!riscv.reg, !riscv.reg) -> () - +// CHECK-NEXT: %C_2 = builtin.unrealized_conversion_cast %C : memref<3x2xf64> to !riscv.reg +// CHECK-NEXT: %C_3 = builtin.unrealized_conversion_cast %C : memref<3x2xf64> to !riscv.reg +// CHECK-NEXT: snitch_stream.streaming_region { +// CHECK-NEXT: patterns = [ +// CHECK-NEXT: #snitch_stream.stride_pattern +// CHECK-NEXT: ] +// CHECK-NEXT: } ins(%C_2, %C_3 : !riscv.reg, !riscv.reg) { +// CHECK-NEXT: ^{{.*}}(%c0 : !stream.readable, %c1 : !stream.readable): +// CHECK-NEXT: %c0_1 = builtin.unrealized_conversion_cast %c0 : !stream.readable to !stream.readable +// CHECK-NEXT: %c1_1 = builtin.unrealized_conversion_cast %c1 : !stream.readable to !stream.readable +// CHECK-NEXT: "test.op"(%c0_1, %c1_1) : (!stream.readable, !stream.readable) -> () +// CHECK-NEXT: } %D, %E = "test.op"() : () -> (memref<1x1x8x8xf64>, memref<1x1x3x3xf64>) // CHECK-NEXT: %D, %E = "test.op"() : () -> (memref<1x1x8x8xf64>, memref<1x1x3x3xf64>) @@ -74,14 +83,19 @@ memref_stream.streaming_region { "test.op"(%d_stream, %e_stream) : (!stream.readable, !stream.readable) -> () } -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %{{.*}} : memref<1x1x8x8xf64> to !riscv.reg -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %{{.*}} : memref<1x1x3x3xf64> to !riscv.reg -// CHECK-NEXT: "snitch_stream.streaming_region"(%{{.*}}, %{{.*}}) <{"stride_patterns" = [#snitch_stream.stride_pattern, #snitch_stream.stride_pattern], "operandSegmentSizes" = array}> ({ -// CHECK-NEXT: ^{{.*}}(%{{.*}} : !stream.readable, %{{.*}} : !stream.readable): -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %{{.*}} : !stream.readable to !stream.readable -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %{{.*}} : !stream.readable to !stream.readable -// CHECK-NEXT: "test.op"(%{{.*}}, %{{.*}}) : (!stream.readable, !stream.readable) -> () -// CHECK-NEXT: }) : (!riscv.reg, !riscv.reg) -> () +// CHECK-NEXT: %D_1 = builtin.unrealized_conversion_cast %D : memref<1x1x8x8xf64> to !riscv.reg +// CHECK-NEXT: %E_1 = builtin.unrealized_conversion_cast %E : memref<1x1x3x3xf64> to !riscv.reg +// CHECK-NEXT: snitch_stream.streaming_region { +// CHECK-NEXT: patterns = [ +// CHECK-NEXT: #snitch_stream.stride_pattern, +// CHECK-NEXT: #snitch_stream.stride_pattern +// CHECK-NEXT: ] +// CHECK-NEXT: } ins(%D_1, %E_1 : !riscv.reg, !riscv.reg) { +// CHECK-NEXT: ^{{.*}}(%d_stream : !stream.readable, %e_stream : !stream.readable): +// CHECK-NEXT: %d_stream_1 = builtin.unrealized_conversion_cast %d_stream : !stream.readable to !stream.readable +// CHECK-NEXT: %e_stream_1 = builtin.unrealized_conversion_cast %e_stream : !stream.readable to !stream.readable +// CHECK-NEXT: "test.op"(%d_stream_1, %e_stream_1) : (!stream.readable, !stream.readable) -> () +// CHECK-NEXT: } %F = "test.op"() : () -> memref<8x8xf64> // CHECK-NEXT: %F = "test.op"() : () -> memref<8x8xf64> @@ -97,16 +111,22 @@ memref_stream.streaming_region { "test.op"(%x_stream, %w_stream, %b_stream) : (!stream.readable, !stream.readable, !stream.readable) -> () } -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %{{.*}} : memref<8x8xf64> to !riscv.reg -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %{{.*}} : memref<8x8xf64> to !riscv.reg -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %{{.*}} : memref<8x8xf64> to !riscv.reg -// CHECK-NEXT: "snitch_stream.streaming_region"(%{{.*}}, %{{.*}}, %{{.*}}) <{"stride_patterns" = [#snitch_stream.stride_pattern, #snitch_stream.stride_pattern, #snitch_stream.stride_pattern], "operandSegmentSizes" = array}> ({ -// CHECK-NEXT: ^{{.*}}(%{{.*}} : !stream.readable, %{{.*}} : !stream.readable, %{{.*}} : !stream.readable): -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %{{.*}} : !stream.readable to !stream.readable -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %{{.*}} : !stream.readable to !stream.readable -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %{{.*}} : !stream.readable to !stream.readable -// CHECK-NEXT: "test.op"(%{{.*}}, %{{.*}}, %{{.*}}) : (!stream.readable, !stream.readable, !stream.readable) -> () -// CHECK-NEXT: }) : (!riscv.reg, !riscv.reg, !riscv.reg) -> () +// CHECK-NEXT: %F_1 = builtin.unrealized_conversion_cast %F : memref<8x8xf64> to !riscv.reg +// CHECK-NEXT: %F_2 = builtin.unrealized_conversion_cast %F : memref<8x8xf64> to !riscv.reg +// CHECK-NEXT: %F_3 = builtin.unrealized_conversion_cast %F : memref<8x8xf64> to !riscv.reg +// CHECK-NEXT: snitch_stream.streaming_region { +// CHECK-NEXT: patterns = [ +// CHECK-NEXT: #snitch_stream.stride_pattern, +// CHECK-NEXT: #snitch_stream.stride_pattern, +// CHECK-NEXT: #snitch_stream.stride_pattern +// CHECK-NEXT: ] +// CHECK-NEXT: } ins(%F_1, %F_2, %F_3 : !riscv.reg, !riscv.reg, !riscv.reg) { +// CHECK-NEXT: ^{{.*}}(%x_stream : !stream.readable, %w_stream : !stream.readable, %b_stream : !stream.readable): +// CHECK-NEXT: %x_stream_1 = builtin.unrealized_conversion_cast %x_stream : !stream.readable to !stream.readable +// CHECK-NEXT: %w_stream_1 = builtin.unrealized_conversion_cast %w_stream : !stream.readable to !stream.readable +// CHECK-NEXT: %b_stream_1 = builtin.unrealized_conversion_cast %b_stream : !stream.readable to !stream.readable +// CHECK-NEXT: "test.op"(%x_stream_1, %w_stream_1, %b_stream_1) : (!stream.readable, !stream.readable, !stream.readable) -> () +// CHECK-NEXT: } %G, %H = "test.op"() : () -> (f64, memref<16x16xf64>) // CHECK-NEXT: %G, %H = "test.op"() : () -> (f64, memref<16x16xf64>) @@ -125,26 +145,30 @@ memref_stream.streaming_region { } } -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %{{.*}} : memref<16x16xf64> to !riscv.reg -// CHECK-NEXT: "snitch_stream.streaming_region"(%{{.*}}) <{"stride_patterns" = [#snitch_stream.stride_pattern], "operandSegmentSizes" = array}> ({ -// CHECK-NEXT: ^{{.*}}(%{{.*}} : !stream.writable): -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %{{.*}} : !stream.writable to !stream.writable -// CHECK-NEXT: %{{.*}} = arith.constant 0 : i32 -// CHECK-NEXT: %{{.*}} = arith.constant 1 : i32 -// CHECK-NEXT: %{{.*}} = arith.constant 256 : i32 -// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} : i32 { -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %{{.*}} : !stream.writable to !stream.writable -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %{{.*}} : f64 to !riscv.freg +// CHECK-NEXT: %H_1 = builtin.unrealized_conversion_cast %H : memref<16x16xf64> to !riscv.reg +// CHECK-NEXT: snitch_stream.streaming_region { +// CHECK-NEXT: patterns = [ +// CHECK-NEXT: #snitch_stream.stride_pattern +// CHECK-NEXT: ] +// CHECK-NEXT: } outs(%H_1 : !riscv.reg) { +// CHECK-NEXT: ^{{.*}}(%h_stream : !stream.writable): +// CHECK-NEXT: %h_stream_1 = builtin.unrealized_conversion_cast %h_stream : !stream.writable to !stream.writable +// CHECK-NEXT: %c0_2 = arith.constant 0 : i32 +// CHECK-NEXT: %c1_2 = arith.constant 1 : i32 +// CHECK-NEXT: %c256 = arith.constant 256 : i32 +// CHECK-NEXT: scf.for %i = %c0_2 to %c256 step %c1_2 : i32 { +// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %h_stream_1 : !stream.writable to !stream.writable +// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %G : f64 to !riscv.freg // CHECK-NEXT: %{{.*}} = riscv.fmv.d %{{.*}} : (!riscv.freg) -> !riscv.freg // CHECK-NEXT: riscv_snitch.write %{{.*}} to %{{.*}} : !riscv.freg // CHECK-NEXT: } -// CHECK-NEXT: }) : (!riscv.reg) -> () +// CHECK-NEXT: } %I, %J, %K = "test.op"() : () -> (memref<3x5xf64>, memref<5x8xf64>, memref<3x8xf64>) // CHECK-NEXT: %I, %J, %K = "test.op"() : () -> (memref<3x5xf64>, memref<5x8xf64>, memref<3x8xf64>) -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %I : memref<3x5xf64> to !riscv.reg -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %J : memref<5x8xf64> to !riscv.reg -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %K : memref<3x8xf64> to !riscv.reg +// CHECK-NEXT: %I_1 = builtin.unrealized_conversion_cast %I : memref<3x5xf64> to !riscv.reg +// CHECK-NEXT: %J_1 = builtin.unrealized_conversion_cast %J : memref<5x8xf64> to !riscv.reg +// CHECK-NEXT: %K_1 = builtin.unrealized_conversion_cast %K : memref<3x8xf64> to !riscv.reg // more complex maps memref_stream.streaming_region { @@ -158,13 +182,19 @@ memref_stream.streaming_region { %res = "test.op"() : () -> f64 memref_stream.yield %res : f64 } -// CHECK-NEXT: "snitch_stream.streaming_region"(%{{.*}}, %{{.*}}, %{{.*}}) <{"stride_patterns" = [#snitch_stream.stride_pattern, #snitch_stream.stride_pattern, #snitch_stream.stride_pattern], "operandSegmentSizes" = array}> ({ -// CHECK-NEXT: ^{{.*}}(%{{.*}} : !stream.readable, %{{.*}} : !stream.readable, %{{.*}} : !stream.writable): -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %{{.*}} : !stream.readable to !stream.readable -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %{{.*}} : !stream.readable to !stream.readable -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %{{.*}} : !stream.writable to !stream.writable -// CHECK-NEXT: %{{.*}} = "test.op"() : () -> f64 -// CHECK-NEXT: memref_stream.yield %{{.*}} : f64 -// CHECK-NEXT: }) : (!riscv.reg, !riscv.reg, !riscv.reg) -> () +// CHECK-NEXT: snitch_stream.streaming_region { +// CHECK-NEXT: patterns = [ +// CHECK-NEXT: #snitch_stream.stride_pattern, +// CHECK-NEXT: #snitch_stream.stride_pattern, +// CHECK-NEXT: #snitch_stream.stride_pattern +// CHECK-NEXT: ] +// CHECK-NEXT: } ins(%I_1, %J_1 : !riscv.reg, !riscv.reg) outs(%K_1 : !riscv.reg) { +// CHECK-NEXT: ^{{.*}}(%i_1 : !stream.readable, %j : !stream.readable, %k : !stream.writable): +// CHECK-NEXT: %i_2 = builtin.unrealized_conversion_cast %i_1 : !stream.readable to !stream.readable +// CHECK-NEXT: %j_1 = builtin.unrealized_conversion_cast %j : !stream.readable to !stream.readable +// CHECK-NEXT: %k_1 = builtin.unrealized_conversion_cast %k : !stream.writable to !stream.writable +// CHECK-NEXT: %res = "test.op"() : () -> f64 +// CHECK-NEXT: memref_stream.yield %res : f64 +// CHECK-NEXT: } // CHECK-NEXT: } diff --git a/tests/filecheck/transforms/snitch_register_allocation.mlir b/tests/filecheck/transforms/snitch_register_allocation.mlir index 83f560927f..beeed82309 100644 --- a/tests/filecheck/transforms/snitch_register_allocation.mlir +++ b/tests/filecheck/transforms/snitch_register_allocation.mlir @@ -19,8 +19,12 @@ // CHECK: builtin.module { // CHECK-NEXT: %ptr0, %ptr1, %ptr2 = "test.op"() : () -> (!riscv.reg, !riscv.reg, !riscv.reg) -// CHECK-NEXT: "snitch_stream.streaming_region"(%ptr0, %ptr1, %ptr2) <{"stride_patterns" = [#snitch_stream.stride_pattern], "operandSegmentSizes" = array}> ({ -// CHECK-NEXT: ^0(%s0 : !stream.readable>, %s1 : !stream.readable>, %s2 : !stream.writable>): +// CHECK-NEXT: snitch_stream.streaming_region { +// CHECK-NEXT: patterns = [ +// CHECK-NEXT: #snitch_stream.stride_pattern +// CHECK-NEXT: ] +// CHECK-NEXT: } ins(%ptr0, %ptr1 : !riscv.reg, !riscv.reg) outs(%ptr2 : !riscv.reg) { +// CHECK-NEXT: ^{{.*}}(%s0 : !stream.readable>, %s1 : !stream.readable>, %s2 : !stream.writable>): // CHECK-NEXT: %c5 = riscv.li 5 : !riscv.reg // CHECK-NEXT: riscv_snitch.frep_outer %c5 { // CHECK-NEXT: %x = riscv_snitch.read from %s0 : !riscv.freg @@ -28,6 +32,6 @@ // CHECK-NEXT: %r = riscv.fadd.d %x, %y : (!riscv.freg, !riscv.freg) -> !riscv.freg // CHECK-NEXT: riscv_snitch.write %r to %s2 : !riscv.freg // CHECK-NEXT: } -// CHECK-NEXT: }) : (!riscv.reg, !riscv.reg, !riscv.reg) -> () +// CHECK-NEXT: } // CHECK-NEXT: } diff --git a/xdsl/dialects/snitch_stream.py b/xdsl/dialects/snitch_stream.py index aaf1101e72..0e6cdaeb99 100644 --- a/xdsl/dialects/snitch_stream.py +++ b/xdsl/dialects/snitch_stream.py @@ -23,6 +23,7 @@ from collections.abc import Iterator, Sequence from itertools import product +from typing import cast from xdsl.dialects import riscv from xdsl.dialects.builtin import ( @@ -46,7 +47,7 @@ region_def, var_operand_def, ) -from xdsl.parser import AttrParser +from xdsl.parser import AttrParser, Parser from xdsl.printer import Printer from xdsl.traits import NoTerminator from xdsl.utils.exceptions import VerifyException @@ -248,6 +249,110 @@ def __init__( }, ) + def print(self, printer: Printer): + with printer.indented(): + printer.print_string(" {") + if self.stride_patterns.data: + printer.print_string("\npatterns = [") + with printer.indented(): + printer.print_list( + self.stride_patterns.data, + lambda attr: printer.print("\n", attr), + delimiter=",", + ) + printer.print_string("\n]") + else: + printer.print_string("\npatterns = []") + printer.print_string("\n}") + + if self.inputs: + printer.print_string(" ins(") + printer.print_list(self.inputs, printer.print_ssa_value) + printer.print_string(" : ") + printer.print_list((i.type for i in self.inputs), printer.print_attribute) + printer.print_string(")") + + if self.outputs: + printer.print_string(" outs(") + printer.print_list(self.outputs, printer.print_ssa_value) + printer.print_string(" : ") + printer.print_list((o.type for o in self.outputs), printer.print_attribute) + printer.print_string(")") + + if self.attributes: + printer.print(" attrs = ") + printer.print_op_attributes(self.attributes) + + printer.print_string(" ") + printer.print_region(self.body) + + @classmethod + def parse(cls, parser: Parser) -> StreamingRegionOp: + parser.parse_punctuation("{") + parser.parse_identifier("stride_patterns") + parser.parse_punctuation("=") + + patterns = parser.parse_attribute() + if not isinstance(patterns, ArrayAttr): + parser.raise_error(f"Expected ArrayAttr {patterns}") + patterns = cast(ArrayAttr[Attribute], patterns) + for pattern in patterns: + if not isinstance(pattern, StridePattern): + parser.raise_error(f"Expected StridePattern {pattern}") + patterns = cast(ArrayAttr[StridePattern], patterns) + + parser.parse_punctuation("}") + + pos = parser.pos + if parser.parse_optional_characters("ins"): + parser.parse_punctuation("(") + unresolved_ins = parser.parse_comma_separated_list( + Parser.Delimiter.NONE, parser.parse_unresolved_operand + ) + parser.parse_punctuation(":") + ins_types = parser.parse_comma_separated_list( + Parser.Delimiter.NONE, parser.parse_type + ) + parser.parse_punctuation(")") + ins = parser.resolve_operands(unresolved_ins, ins_types, pos) + else: + ins = () + + pos = parser.pos + if parser.parse_optional_characters("outs"): + parser.parse_punctuation("(") + unresolved_outs = parser.parse_comma_separated_list( + Parser.Delimiter.NONE, parser.parse_unresolved_operand + ) + parser.parse_punctuation(":") + outs_types = parser.parse_comma_separated_list( + Parser.Delimiter.NONE, parser.parse_type + ) + parser.parse_punctuation(")") + outs = parser.resolve_operands(unresolved_outs, outs_types, pos) + else: + outs = () + + if parser.parse_optional_keyword("attrs"): + parser.parse_punctuation("=") + extra_attrs = parser.expect( + parser.parse_optional_attr_dict, "expect extra attributes" + ) + else: + extra_attrs = {} + + body = parser.parse_region() + + generic = cls( + ins, + outs, + patterns, + body, + ) + generic.attributes |= extra_attrs + + return generic + SnitchStream = Dialect( "snitch_stream", From 116041433cf8d876a8fdcdc50bd33e5cc4f439a4 Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Sun, 7 Jul 2024 16:18:21 +0100 Subject: [PATCH 2/2] use attributes to mark attributes --- xdsl/dialects/snitch_stream.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xdsl/dialects/snitch_stream.py b/xdsl/dialects/snitch_stream.py index 0e6cdaeb99..a8efb3b191 100644 --- a/xdsl/dialects/snitch_stream.py +++ b/xdsl/dialects/snitch_stream.py @@ -280,7 +280,7 @@ def print(self, printer: Printer): printer.print_string(")") if self.attributes: - printer.print(" attrs = ") + printer.print(" attributes = ") printer.print_op_attributes(self.attributes) printer.print_string(" ") @@ -333,7 +333,7 @@ def parse(cls, parser: Parser) -> StreamingRegionOp: else: outs = () - if parser.parse_optional_keyword("attrs"): + if parser.parse_optional_keyword("attributes"): parser.parse_punctuation("=") extra_attrs = parser.expect( parser.parse_optional_attr_dict, "expect extra attributes"