From 162315b1e2e6199e5cd22c7c5204497522386464 Mon Sep 17 00:00:00 2001
From: zhangqirui
Date: Tue, 2 Jul 2024 15:28:41 +0800
Subject: [PATCH] feat(triton-linalg): add ci for triton-linalg

---
 .github/ci_script/combine_log.py              |  52 ++
 .github/ci_script/file_guard.py               |  37 ++
 .github/ci_script/triton-linalg-ci_script.sh  |  79 +++
 .github/workflows/triton-linalg_ci.yaml       |  21 +
 test/Conversion/arith-to-linalg.mlir          |  10 +-
 test/Conversion/triton-to-linalg.mlir         | 619 ++++++++++++++++--
 test/Dialect/LinalgExt/invalid.mlir           | 149 ++---
 test/Dialect/LinalgExt/ops.mlir               | 106 ++-
 .../Dialect/Triton/extract-move-backward.mlir |   6 +-
 .../Triton/extractslice-move-backward.mlir    |   4 +-
 test/Pipelines/pipeline.mlir                  |  39 ++
 test/lit.cfg.py                               |   2 +-
 tools/ci/daily/triton-linalg_daliy.pipeline   | 130 ++++
 tools/scripts/lint_check/common.sh            |  27 +
 tools/scripts/lint_check/format_diff.py       | 171 +++++
 tools/scripts/lint_check/lint.sh              |  49 ++
 tools/scripts/test_triton-linalg.sh           |  60 ++
 17 files changed, 1363 insertions(+), 198 deletions(-)
 create mode 100644 .github/ci_script/combine_log.py
 create mode 100644 .github/ci_script/file_guard.py
 create mode 100644 .github/ci_script/triton-linalg-ci_script.sh
 create mode 100644 .github/workflows/triton-linalg_ci.yaml
 create mode 100644 tools/ci/daily/triton-linalg_daliy.pipeline
 create mode 100644 tools/scripts/lint_check/common.sh
 create mode 100755 tools/scripts/lint_check/format_diff.py
 create mode 100755 tools/scripts/lint_check/lint.sh
 create mode 100644 tools/scripts/test_triton-linalg.sh

diff --git a/.github/ci_script/combine_log.py b/.github/ci_script/combine_log.py
new file mode 100644
index 0000000..60c245a
--- /dev/null
+++ b/.github/ci_script/combine_log.py
@@ -0,0 +1,52 @@
+import time
+import os
+import argparse
+'''
+    Combine sub logs into a single output log while a job runs, and watch
+    the result information fed back from the job server. When the status
+    reads "success" or "fail" (or "error"), exit the script; otherwise
+    keep monitoring the job information every 2 seconds.
+    output_path: the target file to append sub logs to.
+    list_path: the list of sub log names. When it is updated, the
+        corresponding files are appended to the output file.
+    list_dir_path: the directory where the sub logs are stored.
+    status_path: the path of the status file.
+'''
+
+def combine_log(output_path, list_path, list_dir_path, status_path):
+    # list_pos stores the last position the list-file pointer pointed to.
+    list_pos = 0
+    while True:
+        with open(list_path, 'r') as list_file:
+            list_file.seek(list_pos)
+            # Read all lines starting from list_pos.
+            items = list_file.readlines()
+            # Update list_pos.
+            list_pos = list_file.tell()
+        # Append every newly listed sub log to the output file.
+        items.sort()
+        for item in items:
+            sub_path = item.strip()
+            if sub_path != "":
+                file_name = list_dir_path + '/' + sub_path
+                if os.path.exists(file_name):
+                    os.system('cat ' + file_name + ' >> ' + output_path)
+        # Check the status file: exit the loop when it reads "success",
+        # "fail", or "error"; otherwise sleep a while and loop again.
+        with open(status_path) as status_file:
+            status = status_file.readline().strip()
+        if any(word in status.lower() for word in ("success", "fail", "error")):
+            break
+        else:
+            time.sleep(2)
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description="Monitor and concatenate files based on a list.")
+    parser.add_argument('output_path', type=str, help='The path to the output file.')
+    parser.add_argument('list_path', type=str, help='The path to the list file containing sub-paths.')
+    parser.add_argument('list_dir_path', type=str, help='The base directory where sub-paths are located.')
+    parser.add_argument('status_path', type=str, help='The path to the status file.')
+
+    args = parser.parse_args()
+    combine_log(args.output_path, args.list_path, args.list_dir_path, args.status_path)
+
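For reference, the producer side of the file protocol that combine_log.py consumes looks roughly like the following. This is a minimal, hypothetical sketch and not part of the patch: the request_root path and the log name are made up, and only the status/log_list/sub_logs layout comes from the scripts in this patch.

import os

# Hypothetical request directory; the real one is created by the CI script below.
request_root = "/tmp/triton-linalg_ci_demo"
os.makedirs(os.path.join(request_root, "sub_logs"), exist_ok=True)

# 1. Write a finished sub log under sub_logs/.
with open(os.path.join(request_root, "sub_logs", "build.log"), "w") as f:
    f.write("build ok\n")

# 2. Publish its name in log_list; combine_log.py cats it onto the output log.
with open(os.path.join(request_root, "log_list"), "a") as f:
    f.write("build.log\n")

# 3. Write a terminal status; the watcher scripts exit on success/fail/error.
with open(os.path.join(request_root, "status"), "w") as f:
    f.write("success\n")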
diff --git a/.github/ci_script/file_guard.py b/.github/ci_script/file_guard.py
new file mode 100644
index 0000000..6981849
--- /dev/null
+++ b/.github/ci_script/file_guard.py
@@ -0,0 +1,37 @@
+import time
+import os
+import argparse
+
+def file_guard(guard_status_file, guard_log_file):
+    # where stores the last position the log-file pointer pointed to.
+    where = 0
+    while True:
+        with open(guard_log_file, "r") as file:
+            file.seek(where)
+            # If any lines were read, call the system echo to print each line.
+            for line in file.readlines():
+                new_line = line.strip().replace("\'", "_").replace("\"", "_")
+                os.system('echo ' + "'%s'" % new_line)
+            # Update where.
+            where = file.tell()
+        # Check the status file and end the process on "success" or "fail".
+        with open(guard_status_file, "r") as status_file:
+            line = status_file.readline().strip()
+        if "success" in line.lower():
+            print("Task success.")
+            break
+        elif "fail" in line.lower() or "error" in line.lower():
+            print("Task fail.")
+            exit(-1)
+        # Sleep for a while before polling again.
+        time.sleep(2)
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description="Monitor a log file and echo lines, check status to stop.")
+    parser.add_argument('guard_status_file', type=str, help='The path to the status file.')
+    parser.add_argument('guard_log_file', type=str, help='The path to the log file.')
+
+    args = parser.parse_args()
+
+    file_guard(args.guard_status_file, args.guard_log_file)
+
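A quick way to exercise file_guard.py locally, assuming it is run from the repository root, is sketched below. The /tmp path and the log line are hypothetical; the "Task success." message and exit behavior come from the script above.

import pathlib
import subprocess
import time

root = pathlib.Path("/tmp/guard_demo")   # hypothetical scratch directory
root.mkdir(exist_ok=True)
(root / "status").write_text("working\n")
(root / "log").write_text("")

# Start the guard on the status/log pair, in the same order the CI script uses.
proc = subprocess.Popen(
    ["python3", ".github/ci_script/file_guard.py",
     str(root / "status"), str(root / "log")])

# Append a line; the guard echoes it within one 2-second poll.
with open(root / "log", "a") as f:
    f.write("hello from the job server\n")
time.sleep(3)

# Mark the task finished; the guard prints "Task success." and exits 0.
(root / "status").write_text("success\n")
proc.wait()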
diff --git a/.github/ci_script/triton-linalg-ci_script.sh b/.github/ci_script/triton-linalg-ci_script.sh
new file mode 100644
index 0000000..f0daf11
--- /dev/null
+++ b/.github/ci_script/triton-linalg-ci_script.sh
@@ -0,0 +1,79 @@
+#!/bin/bash
+# Get the PR id.
+PR_string=$(echo $GITHUB_REF | grep -Eo "/[0-9]*/")
+pr_id=(${PR_string//// })
+
+# Generate a timestamp.
+current=`date "+%Y-%m-%d %H:%M:%S"`
+timeStamp=`date -d "$current" +%s`
+currentTimeStamp=$((timeStamp*1000+10#`date "+%N"`/1000000))
+
+# Temporarily set to MLU370.
+card_type="MLU370-S4"
+
+# Default repo name.
+repo_name="triton-linalg"
+# Repo CI root path.
+repo_root="/home/user1/${repo_name}_ci/"
+if [ ! -d "$repo_root" ];then
+    mkdir "$repo_root"
+fi
+# Repo CI requests path.
+requests_path="$repo_root/requests"
+if [ ! -d "$requests_path" ];then
+    mkdir "$requests_path"
+fi
+
+# Generate the name of this CI request.
+request_name="${repo_name}_${pr_id}_${currentTimeStamp}_${card_type}.rqt"
+
+# Generate the files and dirs for this request.
+request_root="$repo_root/$request_name/"
+sub_logs_path="$request_root/sub_logs/"
+
+
+if [ ! -d "$request_root" ];then
+    mkdir "$request_root"
+fi
+
+if [ ! -d "$sub_logs_path" ];then
+    mkdir "$sub_logs_path"
+fi
+
+echo "working" > "$request_root/status"
+chmod o+w "$request_root/status"
+
+if [ ! -f "$request_root/log" ];then
+    touch "$request_root/log"
+fi
+
+chmod o+w "$request_root/log"
+
+if [ ! -f "$request_root/log_list" ];then
+    touch "$request_root/log_list"
+fi
+
+chmod o+w "$request_root/log_list"
+
+# Generate the request file.
+echo "repo:${repo_name}" > "$requests_path/${request_name}"
+echo "pr_id:${pr_id}" >> "$requests_path/${request_name}"
+echo "timestamp:${currentTimeStamp}" >> "$requests_path/${request_name}"
+
+# Change the dir group for the server and client; otherwise ftp may raise an
+# error when the server/client tries to delete the request.
+# Start the watcher scripts.
+python3 .github/ci_script/file_guard.py "$request_root/status" "$request_root/log" &
+python3 .github/ci_script/combine_log.py "$request_root/log" "$request_root/log_list" "$request_root/sub_logs" "$request_root/status" &
+
+wait
+
+status=$( head -n 1 "${request_root}/status" )
+
+if [ "$status" != "success" ];then
+    echo "${status}"
+    exit 1
+else
+    echo "${status}"
+    exit 0
+fi
diff --git a/.github/workflows/triton-linalg_ci.yaml b/.github/workflows/triton-linalg_ci.yaml
new file mode 100644
index 0000000..e1de201
--- /dev/null
+++ b/.github/workflows/triton-linalg_ci.yaml
@@ -0,0 +1,21 @@
+name: triton-linalg_ci
+
+on:
+  push:
+    branches: [master]
+  pull_request:
+    branches: [master]
+jobs:
+  test:
+    strategy:
+      matrix:
+        triton-linalg_version: [v1.1.1]
+    runs-on: self-hosted
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          submodules: 'true'
+
+      - name: run_triton-linalg_ci
+        run: >
+          bash .github/ci_script/triton-linalg-ci_script.sh
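On the other side of the request queue, a job server is expected to pick up the .rqt file and drive the status/log_list/sub_logs files that the script above watches. That server is not part of this patch; the sketch below is a hypothetical, minimal consumer inferred from the protocol (the field names match the request file written above, the test command is a placeholder pointing at the test script added by this patch, and everything else is made up).

import glob
import os
import subprocess

requests_path = "/home/user1/triton-linalg_ci/requests"  # matches the script above
for rqt in sorted(glob.glob(os.path.join(requests_path, "*.rqt"))):
    # The request dir created by the client is named after the .rqt file.
    request_root = os.path.join(os.path.dirname(requests_path),
                                os.path.basename(rqt))
    # Lines look like "repo:triton-linalg", "pr_id:...", "timestamp:...".
    fields = dict(line.strip().split(":", 1) for line in open(rqt))
    print("running job for PR", fields.get("pr_id"))
    # Run the actual test job (placeholder command), capturing its output.
    log_name = "run.log"
    with open(os.path.join(request_root, "sub_logs", log_name), "w") as log:
        rc = subprocess.call(["bash", "tools/scripts/test_triton-linalg.sh"],
                             stdout=log, stderr=subprocess.STDOUT)
    # Publish the sub log, then report the terminal status the client expects.
    with open(os.path.join(request_root, "log_list"), "a") as lst:
        lst.write(log_name + "\n")
    with open(os.path.join(request_root, "status"), "w") as status:
        status.write("success" if rc == 0 else "fail")
    os.remove(rqt)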
diff --git a/test/Conversion/arith-to-linalg.mlir b/test/Conversion/arith-to-linalg.mlir
index c350b35..474f521 100644
--- a/test/Conversion/arith-to-linalg.mlir
+++ b/test/Conversion/arith-to-linalg.mlir
@@ -27,7 +27,7 @@ func.func @const_valid_int(%arg0: tensor<1x16x128x128xi32>) -> tensor<1x16x128x1
 // -----
 func.func @arith_addi(%arg0: tensor<128xi32>, %arg1: tensor<128xi32>) {
   // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<128xi32>
-  // CHECK: %[[MAPPED:.*]] = linalg.map { arith.addi } ins(%arg0, %arg1 : tensor<128xi32>, tensor<128xi32>) outs(%[[INIT]] : tensor<128xi32>)
+  // CHECK: %[[MAPPED:.*]] = linalg.map { arith.addi {overflowFlags = #arith.overflow<none>} } ins(%arg0, %arg1 : tensor<128xi32>, tensor<128xi32>) outs(%[[INIT]] : tensor<128xi32>)
   %0 = arith.addi %arg0, %arg1 : tensor<128xi32>
   return
 }
@@ -35,7 +35,7 @@ func.func @arith_addi(%arg0: tensor<128xi32>, %arg1: tensor<128xi32>) {
 // -----
 func.func @arith_subi(%arg0: tensor<128xi32>, %arg1: tensor<128xi32>) {
   // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<128xi32>
-  // CHECK: %[[MAPPED:.*]] = linalg.map { arith.subi } ins(%arg0, %arg1 : tensor<128xi32>, tensor<128xi32>) outs(%[[INIT]] : tensor<128xi32>)
+  // CHECK: %[[MAPPED:.*]] = linalg.map { arith.subi {overflowFlags = #arith.overflow<none>} } ins(%arg0, %arg1 : tensor<128xi32>, tensor<128xi32>) outs(%[[INIT]] : tensor<128xi32>)
   %0 = arith.subi %arg0, %arg1 : tensor<128xi32>
   return
 }
@@ -43,7 +43,7 @@ func.func @arith_subi(%arg0: tensor<128xi32>, %arg1: tensor<128xi32>) {
 // -----
 func.func @arith_muli(%arg0: tensor<128xi32>, %arg1: tensor<128xi32>) {
   // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<128xi32>
-  // CHECK: %[[MAPPED:.*]] = linalg.map { arith.muli } ins(%arg0, %arg1 : tensor<128xi32>, tensor<128xi32>) outs(%[[INIT]] : tensor<128xi32>)
+  // CHECK: %[[MAPPED:.*]] = linalg.map { arith.muli {overflowFlags = #arith.overflow<none>} } ins(%arg0, %arg1 : tensor<128xi32>, tensor<128xi32>) outs(%[[INIT]] : tensor<128xi32>)
   %0 = arith.muli %arg0, %arg1 : tensor<128xi32>
   return
 }
@@ -131,7 +131,7 @@ func.func @arith_xori(%arg0: tensor<128xi32>, %arg1: tensor<128xi32>) {
 // -----
 func.func @arith_shli(%arg0: tensor<128xi32>, %arg1: tensor<128xi32>) {
   // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<128xi32>
-  // CHECK: %[[MAPPED:.*]] = linalg.map { arith.shli } ins(%arg0, %arg1 : tensor<128xi32>, tensor<128xi32>) outs(%[[INIT]] : tensor<128xi32>)
+  // CHECK: %[[MAPPED:.*]] = linalg.map { arith.shli {overflowFlags = #arith.overflow<none>} } ins(%arg0, %arg1 : tensor<128xi32>, tensor<128xi32>) outs(%[[INIT]] : tensor<128xi32>)
   %0 = arith.shli %arg0, %arg1 : tensor<128xi32>
   return
 }
@@ -388,7 +388,7 @@ func.func @arith_addi_dynamic(%arg0: tensor<128x?xi32>, %arg1: tensor<128x?xi32>
   // CHECK: %[[CST:.*]] = arith.constant 1 : index
   // CHECK: %[[DYNAMIC_DIM:.*]] = tensor.dim %arg0, %[[CST]] : tensor<128x?xi32>
   // CHECK: %[[INIT:.*]] = tensor.empty(%[[DYNAMIC_DIM]]) : tensor<128x?xi32>
-  // CHECK: %[[MAPPED:.*]] = linalg.map { arith.addi } ins(%arg0, %arg1 : tensor<128x?xi32>, tensor<128x?xi32>) outs(%[[INIT]] : tensor<128x?xi32>)
+  // CHECK: %[[MAPPED:.*]] = linalg.map { arith.addi {overflowFlags = #arith.overflow<none>} } ins(%arg0, %arg1 : tensor<128x?xi32>, tensor<128x?xi32>) outs(%[[INIT]] : tensor<128x?xi32>)
   %0 = arith.addi %arg0, %arg1 : tensor<128x?xi32>
   return
 }
diff --git a/test/Conversion/triton-to-linalg.mlir b/test/Conversion/triton-to-linalg.mlir
index fa1e73e..dc4b1eb 100644
--- a/test/Conversion/triton-to-linalg.mlir
+++ b/test/Conversion/triton-to-linalg.mlir
@@ -744,6 +744,24 @@ tt.func @for_iter_args(%arg0: !tt.ptr, %arg1: tensor<128x64xi32>) {
   tt.return
 }
 
+// -----
+tt.func @ext_elemwise_1(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) {
+  // CHECK: tensor.empty
+  // CHECK: linalg_ext.libdevice_call
+  // CHECK-DAG: symbol = "__cn_vector_mul_f32_rn"
+  %0 = tt.extern_elementwise %arg0, %arg1 {libname = "a", libpath = "b", symbol = "__cn_vector_mul_f32_rn", pure = true} : (tensor<16x16xf32>, tensor<16x16xf32>) -> (tensor<16x16xf32>)
+  tt.return
+}
+
+// -----
+tt.func @ext_elemwise_2(%arg0: tensor<16x16xi32>) {
+  // CHECK: tensor.empty
+  // CHECK: linalg_ext.libdevice_call
+  // CHECK-DAG: symbol = "__cn_vector_abs_s32"
+  %0 = tt.extern_elementwise %arg0 {libname = "a", libpath = "b", symbol = "__cn_vector_abs_s32", pure = true} : (tensor<16x16xi32>) -> (tensor<16x16xi32>)
+  tt.return
+}
+
 // -----
 // CHECK-LABEL: @cast_ptr_and_int_scalar
 tt.func @cast_ptr_and_int_scalar(%arg0: !tt.ptr) {
@@ -822,6 +840,76 @@ tt.func @trans_0d(%arg0: tensor) -> tensor {
   tt.return %out : tensor
 }
 
+// -----
+func.func public @scalar_pow(%arg0: f32, %arg1: f32) {
+  // CHECK: linalg_ext.scalar_libdevice_call
+  %0 = tt.extern_elementwise %arg0, %arg1 {libname = "libdevice", libpath = "", symbol = "__cn_scalar_pow_f32", pure = true} : (f32, f32) -> f32
+  return
+}
+
+// -----
+func.func public @scalar_scalbn(%arg0: f32, %arg1: i32) {
+  // CHECK: linalg_ext.scalar_libdevice_call
+  %0 = tt.extern_elementwise %arg0, %arg1 {libname = "libdevice", libpath = "", symbol = "__cn_scalar_scalbn_f32", pure = true} : (f32, i32) -> f32
+  return
+}
+
+// -----
+func.func @scalar_isinf(%arg0: f16) -> i16 {
+  // CHECK: linalg_ext.scalar_libdevice_call
+  %res = tt.extern_elementwise %arg0 {libname = "libdevice", libpath = "", symbol = "__cn_scalar_isinf_f16", pure = true} : (f16) -> i16
+  func.return %res : i16
+}
+
+// -----
+func.func @scalar_addf(%arg0: f32, %arg1: f32) -> f32
{ + // CHECK: linalg_ext.scalar_libdevice_call + %res = tt.extern_elementwise %arg0, %arg1 {libname = "libdevice", libpath = "", symbol = "__cn_scalar_add_f32_tz", pure = true} : (f32, f32) -> f32 + func.return %res : f32 +} + +// ----- +func.func @scalar_addi(%arg0: i32, %arg1: i32) -> i32 { + // CHECK: linalg_ext.scalar_libdevice_call + %res = tt.extern_elementwise %arg0, %arg1 {libname = "libdevice", libpath = "", symbol = "__cn_scalar_add_u32", pure = true} : (i32, i32) -> i32 + func.return %res : i32 +} + +// ----- +func.func @scalar_and(%arg0: i8, %arg1: i8) -> i8 { + // CHECK: linalg_ext.scalar_libdevice_call + %res = tt.extern_elementwise %arg0, %arg1 {libname = "libdevice", libpath = "", symbol = "__cn_scalar_and_bool", pure = true} : (i8, i8) -> i8 + func.return %res : i8 +} + +// ----- +func.func @scalar_or(%arg0: i8, %arg1: i8) -> i8 { + // CHECK: linalg_ext.scalar_libdevice_call + %res = tt.extern_elementwise %arg0, %arg1 {libname = "libdevice", libpath = "", symbol = "__cn_scalar_or_bool", pure = true} : (i8, i8) -> i8 + func.return %res : i8 +} + +// ----- +func.func @scalar_isnan(%arg0: f32) -> i32 { + // CHECK: linalg_ext.scalar_libdevice_call + %res = tt.extern_elementwise %arg0 {libname = "libdevice", libpath = "", symbol = "__cn_scalar_isnan_f32", pure = true} : (f32) -> i32 + func.return %res : i32 +} + +// ----- +func.func @scalar_cast_to_ui8(%arg0: f32) -> i8 { + // CHECK: linalg_ext.scalar_libdevice_call + %res = tt.extern_elementwise %arg0 {libname = "libdevice", libpath = "", symbol = "__cn_scalar_cast_f32_to_u8_tz", pure = true} : (f32) -> i8 + func.return %res : i8 +} + +// ----- +func.func @scalar_lt(%arg0: i16, %arg1: i16) -> i8 { + // CHECK: linalg_ext.scalar_libdevice_call + %res = tt.extern_elementwise %arg0, %arg1 {libname = "libdevice", libpath = "", symbol = "__cn_scalar_lt_u16", pure = true} : (i16, i16) -> i8 + func.return %res : i8 +} + // ----- // CHECK-LABEL: @cmpi_to_fill func.func @cmpi_to_fill(%arg0: i32) { @@ -1229,6 +1317,60 @@ func.func @print_scalar(%arg0: i32, %arg1: f32, %arg2: !tt.ptr, %arg3: i64) return } +// ----- +// CHECK-LABEL: @print_scalar_hex +// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: !tt.ptr, %[[ARG3:.*]]: i64) +// CHECK: %[[ARG4:.*]] = builtin.unrealized_conversion_cast %[[ARG2]] : !tt.ptr to i64 +// CHECK: %[[ARG5:.*]] = tt.get_program_id x : i32 +// CHECK: %[[ARG6:.*]] = tt.get_program_id y : i32 +// CHECK: %[[ARG7:.*]] = tt.get_program_id z : i32 +// CHECK: aux.scalar.print(%[[ARG5]] : i32) {format = "pid ("} +// CHECK: aux.scalar.print(%[[ARG6]] : i32) {format = ", "} +// CHECK: aux.scalar.print(%[[ARG7]] : i32) {format = ", "} +// CHECK: aux.scalar.print {format = ") "} +// CHECK: aux.scalar.print(%[[ARG0]] : i32) {format = "0x%08x"} +// CHECK: %[[ARG8:.*]] = tt.get_program_id x : i32 +// CHECK: %[[ARG9:.*]] = tt.get_program_id y : i32 +// CHECK: %[[ARG10:.*]] = tt.get_program_id z : i32 +// CHECK: aux.scalar.print(%[[ARG8]] : i32) {format = "pid ("} +// CHECK: aux.scalar.print(%[[ARG9]] : i32) {format = ", "} +// CHECK: aux.scalar.print(%[[ARG10]] : i32) {format = ", "} +// CHECK: aux.scalar.print {format = ") "} +// CHECK: aux.scalar.print(%[[ARG0]] : i32) {format = "arg0: \0A0x%08x"} +// CHECK: %[[ARG11:.*]] = tt.get_program_id x : i32 +// CHECK: %[[ARG12:.*]] = tt.get_program_id y : i32 +// CHECK: %[[ARG13:.*]] = tt.get_program_id z : i32 +// CHECK: aux.scalar.print(%[[ARG11]] : i32) {format = "pid ("} +// CHECK: aux.scalar.print(%[[ARG12]] : i32) {format = ", "} +// CHECK: 
aux.scalar.print(%[[ARG13]] : i32) {format = ", "} +// CHECK: aux.scalar.print {format = ") "} +// CHECK: aux.scalar.print(%[[ARG0]] : i32) {format = "arg0, arg10x%08x"} +// CHECK: aux.scalar.print(%[[ARG1]] : f32) {format = "arg0, arg10x%08x"} +// CHECK: %[[ARG14:.*]] = tt.get_program_id x : i32 +// CHECK: %[[ARG15:.*]] = tt.get_program_id y : i32 +// CHECK: %[[ARG16:.*]] = tt.get_program_id z : i32 +// CHECK: aux.scalar.print(%[[ARG14]] : i32) {format = "pid ("} +// CHECK: aux.scalar.print(%[[ARG15]] : i32) {format = ", "} +// CHECK: aux.scalar.print(%[[ARG16]] : i32) {format = ", "} +// CHECK: aux.scalar.print {format = ") "} +// CHECK: aux.scalar.print(%[[ARG4]] : i64) {format = "arg2: %p"} +// CHECK: %[[ARG17:.*]] = tt.get_program_id x : i32 +// CHECK: %[[ARG18:.*]] = tt.get_program_id y : i32 +// CHECK: %[[ARG19:.*]] = tt.get_program_id z : i32 +// CHECK: aux.scalar.print(%[[ARG17]] : i32) {format = "pid ("} +// CHECK: aux.scalar.print(%[[ARG18]] : i32) {format = ", "} +// CHECK: aux.scalar.print(%[[ARG19]] : i32) {format = ", "} +// CHECK: aux.scalar.print {format = ") "} +// CHECK: aux.scalar.print(%[[ARG3]] : i64) {format = "arg3: \0A0x%016llx"} +func.func @print_scalar_hex(%arg0: i32, %arg1: f32, %arg2: !tt.ptr, %arg3: i64) { + tt.print "" { hex = true } : %arg0 : i32 + tt.print "arg0: \n" { hex = true } : %arg0 : i32 + tt.print "arg0, arg1" { hex = true } : %arg0, %arg1 : i32, f32 + tt.print "arg2: " { hex = true } : %arg2 : !tt.ptr + tt.print "arg3: \n" { hex = true } : %arg3 : i64 + return +} + // ----- // CHECK-LABEL: @print_tensor // CHECK-SAME: %[[ARG0:.*]]: tensor<16xi32>, %[[ARG1:.*]]: tensor<2x8xf32>, %[[ARG2:.*]]: tensor<16x!tt.ptr>, %[[ARG3:.*]]: tensor<32xi64>) @@ -1283,11 +1425,47 @@ func.func @print_tensor(%arg0: tensor<16xi32>, %arg1: tensor<2x8xf32>, %arg2: te return } +// ----- +// CHECK-LABEL: @print_tensor_hex +// CHECK-SAME: %[[ARG0:.*]]: tensor<16xi32>, %[[ARG1:.*]]: tensor<2x8xf32>, %[[ARG2:.*]]: tensor<16x!tt.ptr>) +// CHECK: %[[ARG3:.*]] = builtin.unrealized_conversion_cast %[[ARG2]] : tensor<16x!tt.ptr> to tensor<16xi64> +// CHECK: %[[ARG4:.*]] = tt.get_program_id x : i32 +// CHECK: %[[ARG5:.*]] = tt.get_program_id y : i32 +// CHECK: %[[ARG6:.*]] = tt.get_program_id z : i32 +// CHECK: aux.scalar.print(%[[ARG4]] : i32) {format = "pid ("} +// CHECK: aux.scalar.print(%[[ARG5]] : i32) {format = ", "} +// CHECK: aux.scalar.print(%[[ARG6]] : i32) {format = ", "} +// CHECK: aux.scalar.print {format = ") "} +// CHECK: %[[ARG7:.*]] = aux.print(%[[ARG0]] : tensor<16xi32>) {format = "arg0: 0x%08x"} -> (tensor<16xi32>) +// CHECK: %[[ARG8:.*]] = tt.get_program_id x : i32 +// CHECK: %[[ARG9:.*]] = tt.get_program_id y : i32 +// CHECK: %[[ARG10:.*]] = tt.get_program_id z : i32 +// CHECK: aux.scalar.print(%[[ARG8]] : i32) {format = "pid ("} +// CHECK: aux.scalar.print(%[[ARG9]] : i32) {format = ", "} +// CHECK: aux.scalar.print(%[[ARG10]] : i32) {format = ", "} +// CHECK: aux.scalar.print {format = ") "} +// CHECK: %[[ARG11:.*]] = aux.print(%[[ARG7]] : tensor<16xi32>) {format = "arg0, arg10x%08x"} -> (tensor<16xi32>) +// CHECK: %[[ARG12:.*]] = aux.print(%[[ARG1]] : tensor<2x8xf32>) {format = "arg0, arg10x%08x"} -> (tensor<2x8xf32>) +// CHECK: %[[ARG13:.*]] = tt.get_program_id x : i32 +// CHECK: %[[ARG14:.*]] = tt.get_program_id y : i32 +// CHECK: %[[ARG15:.*]] = tt.get_program_id z : i32 +// CHECK: aux.scalar.print(%[[ARG13]] : i32) {format = "pid ("} +// CHECK: aux.scalar.print(%[[ARG14]] : i32) {format = ", "} +// CHECK: aux.scalar.print(%[[ARG15]] : i32) {format = ", 
"} +// CHECK: aux.scalar.print {format = ") "} +// CHECK: %[[ARG16:.*]] = aux.print(%[[ARG3]] : tensor<16xi64>) {format = "arg2: %p"} -> (tensor<16xi64>) +func.func @print_tensor_hex(%arg0: tensor<16xi32>, %arg1: tensor<2x8xf32>, %arg2: tensor<16x!tt.ptr>) { + tt.print "arg0: " { hex = true } : %arg0 : tensor<16xi32> + tt.print "arg0, arg1" { hex = true } : %arg0, %arg1 : tensor<16xi32>, tensor<2x8xf32> + tt.print "arg2: " { hex = true } : %arg2 : tensor<16x!tt.ptr> + return +} + // ----- // CHECK-LABEL: @print_scalar_and_tensor // CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: tensor<16xi32>, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: tensor<2x8xf32>, %[[ARG5:.*]]: tensor<16x!tt.ptr>, %[[ARG6:.*]]: !tt.ptr, %[[ARG7:.*]]: i64, %[[ARG8:.*]]: tensor<32xi64>) -// CHECK: %[[ARG9:.*]] = builtin.unrealized_conversion_cast %arg5 : tensor<16x!tt.ptr> to tensor<16xi64> -// CHECK: %[[ARG10:.*]] = builtin.unrealized_conversion_cast %arg6 : !tt.ptr to i64 +// CHECK: %[[ARG9:.*]] = builtin.unrealized_conversion_cast %[[ARG5]] : tensor<16x!tt.ptr> to tensor<16xi64> +// CHECK: %[[ARG10:.*]] = builtin.unrealized_conversion_cast %[[ARG6]] : !tt.ptr to i64 // CHECK: %[[ARG11:.*]] = tt.get_program_id x : i32 // CHECK: %[[ARG12:.*]] = tt.get_program_id y : i32 // CHECK: %[[ARG13:.*]] = tt.get_program_id z : i32 @@ -1314,55 +1492,99 @@ func.func @print_tensor(%arg0: tensor<16xi32>, %arg1: tensor<2x8xf32>, %arg2: te // CHECK: aux.scalar.print(%[[ARG20]] : i32) {format = ", "} // CHECK: aux.scalar.print(%[[ARG21]] : i32) {format = ", "} // CHECK: aux.scalar.print {format = ") "} -// CHECK: aux.scalar.print(%[[ARG0]] : i32) {format = "arg0, arg1, arg2"} -// CHECK: %[[ARG22:.*]] = aux.print(%[[ARG18]] : tensor<16xi32>) {format = "arg0, arg1, arg2"} -> (tensor<16xi32>) -// CHECK: aux.scalar.print(%[[ARG2]] : f32) {format = "arg0, arg1, arg2"} -// CHECK: %[[ARG23:.*]] = tt.get_program_id x : i32 -// CHECK: %[[ARG24:.*]] = tt.get_program_id y : i32 -// CHECK: %[[ARG25:.*]] = tt.get_program_id z : i32 -// CHECK: aux.scalar.print(%[[ARG23]] : i32) {format = "pid ("} -// CHECK: aux.scalar.print(%[[ARG24]] : i32) {format = ", "} -// CHECK: aux.scalar.print(%[[ARG25]] : i32) {format = ", "} -// CHECK: aux.scalar.print {format = ") "} // CHECK: aux.scalar.print(%[[ARG0]] : i32) {format = ""} // CHECK: aux.scalar.print(%[[ARG2]] : f32) {format = ""} -// CHECK: %[[ARG26:.*]] = aux.print(%[[ARG22]] : tensor<16xi32>) {format = ""} -> (tensor<16xi32>) -// CHECK: %[[ARG27:.*]] = aux.print(%[[ARG4]] : tensor<2x8xf32>) {format = ""} -> (tensor<2x8xf32>) +// CHECK: %[[ARG22:.*]] = aux.print(%[[ARG18]] : tensor<16xi32>) {format = ""} -> (tensor<16xi32>) +// CHECK: %[[ARG23:.*]] = aux.print(%[[ARG4]] : tensor<2x8xf32>) {format = ""} -> (tensor<2x8xf32>) // CHECK: aux.scalar.print(%[[ARG3]] : i32) {format = ""} -// CHECK: %[[ARG28:.*]] = tt.get_program_id x : i32 -// CHECK: %[[ARG29:.*]] = tt.get_program_id y : i32 -// CHECK: %[[ARG30:.*]] = tt.get_program_id z : i32 -// CHECK: aux.scalar.print(%[[ARG28]] : i32) {format = "pid ("} -// CHECK: aux.scalar.print(%[[ARG29]] : i32) {format = ", "} -// CHECK: aux.scalar.print(%[[ARG30]] : i32) {format = ", "} +// CHECK: %[[ARG24:.*]] = tt.get_program_id x : i32 +// CHECK: %[[ARG25:.*]] = tt.get_program_id y : i32 +// CHECK: %[[ARG26:.*]] = tt.get_program_id z : i32 +// CHECK: aux.scalar.print(%[[ARG24]] : i32) {format = "pid ("} +// CHECK: aux.scalar.print(%[[ARG25]] : i32) {format = ", "} +// CHECK: aux.scalar.print(%[[ARG26]] : i32) {format = ", "} // CHECK: aux.scalar.print {format 
= ") "} -// CHECK: %[[ARG31:.*]] = aux.print(%[[ARG26]] : tensor<16xi32>) {format = "arg1, arg2, arg4, arg3"} -> (tensor<16xi32>) +// CHECK: %[[ARG27:.*]] = aux.print(%[[ARG22]] : tensor<16xi32>) {format = "arg1, arg2, arg4, arg3"} -> (tensor<16xi32>) // CHECK: aux.scalar.print(%[[ARG2]] : f32) {format = "arg1, arg2, arg4, arg3"} -// CHECK: %[[ARG32:.*]] = aux.print(%[[ARG27]] : tensor<2x8xf32>) {format = "arg1, arg2, arg4, arg3"} -> (tensor<2x8xf32>) +// CHECK: %[[ARG28:.*]] = aux.print(%[[ARG23]] : tensor<2x8xf32>) {format = "arg1, arg2, arg4, arg3"} -> (tensor<2x8xf32>) // CHECK: aux.scalar.print(%[[ARG3]] : i32) {format = "arg1, arg2, arg4, arg3"} -// CHECK: %[[ARG33:.*]] = tt.get_program_id x : i32 -// CHECK: %[[ARG34:.*]] = tt.get_program_id y : i32 -// CHECK: %[[ARG35:.*]] = tt.get_program_id z : i32 -// CHECK: aux.scalar.print(%[[ARG33]] : i32) {format = "pid ("} -// CHECK: aux.scalar.print(%[[ARG34]] : i32) {format = ", "} -// CHECK: aux.scalar.print(%[[ARG35]] : i32) {format = ", "} +// CHECK: %[[ARG29:.*]] = tt.get_program_id x : i32 +// CHECK: %[[ARG30:.*]] = tt.get_program_id y : i32 +// CHECK: %[[ARG31:.*]] = tt.get_program_id z : i32 +// CHECK: aux.scalar.print(%[[ARG29]] : i32) {format = "pid ("} +// CHECK: aux.scalar.print(%[[ARG30]] : i32) {format = ", "} +// CHECK: aux.scalar.print(%[[ARG31]] : i32) {format = ", "} // CHECK: aux.scalar.print {format = ") "} -// CHECK: %[[ARG36:.*]] = aux.print(%[[ARG31]] : tensor<16xi32>) {format = "arg1, arg5, arg0, arg6, arg7, arg8"} -> (tensor<16xi32>) -// CHECK: %[[ARG37:.*]] = aux.print(%[[ARG9]] : tensor<16xi64>) {format = "arg1, arg5, arg0, arg6, arg7, arg8%p"} -> (tensor<16xi64>) -// CHECK: aux.scalar.print(%[[ARG0]] : i32) {format = "arg1, arg5, arg0, arg6, arg7, arg8%p"} +// CHECK: %[[ARG32:.*]] = aux.print(%[[ARG27]] : tensor<16xi32>) {format = "arg1, arg5, arg0, arg6, arg7, arg8"} -> (tensor<16xi32>) +// CHECK: %[[ARG33:.*]] = aux.print(%[[ARG9]] : tensor<16xi64>) {format = "arg1, arg5, arg0, arg6, arg7, arg8%p"} -> (tensor<16xi64>) +// CHECK: aux.scalar.print(%[[ARG0]] : i32) {format = "arg1, arg5, arg0, arg6, arg7, arg8"} // CHECK: aux.scalar.print(%[[ARG10]] : i64) {format = "arg1, arg5, arg0, arg6, arg7, arg8%p"} -// CHECK: aux.scalar.print(%[[ARG7]] : i64) {format = "arg1, arg5, arg0, arg6, arg7, arg8%p"} -// CHECK: %[[ARG38:.*]] = aux.print(%[[ARG8]] : tensor<32xi64>) {format = "arg1, arg5, arg0, arg6, arg7, arg8%p"} -> (tensor<32xi64>) +// CHECK: aux.scalar.print(%[[ARG7]] : i64) {format = "arg1, arg5, arg0, arg6, arg7, arg8"} +// CHECK: %[[ARG34:.*]] = aux.print(%[[ARG8]] : tensor<32xi64>) {format = "arg1, arg5, arg0, arg6, arg7, arg8"} -> (tensor<32xi64>) func.func @print_scalar_and_tensor(%arg0: i32, %arg1: tensor<16xi32>, %arg2: f32, %arg3: i32, %arg4: tensor<2x8xf32>, %arg5: tensor<16x!tt.ptr>, %arg6: !tt.ptr, %arg7: i64, %arg8: tensor<32xi64>) { tt.print "" { hex = false } : %arg0, %arg1 : i32, tensor<16xi32> tt.print "arg0, arg1, arg2" { hex = false } : %arg0, %arg1, %arg2 : i32, tensor<16xi32>, f32 - tt.print "arg0, arg1, arg2" { hex = false } : %arg0, %arg1, %arg2 : i32, tensor<16xi32>, f32 tt.print "" { hex = false } : %arg0, %arg2, %arg1, %arg4, %arg3 : i32, f32, tensor<16xi32>, tensor<2x8xf32>, i32 tt.print "arg1, arg2, arg4, arg3" { hex = false } : %arg1, %arg2, %arg4, %arg3 : tensor<16xi32>, f32, tensor<2x8xf32>, i32 tt.print "arg1, arg5, arg0, arg6, arg7, arg8" { hex = false } : %arg1, %arg5, %arg0, %arg6, %arg7, %arg8 : tensor<16xi32>, tensor<16x!tt.ptr>, i32, !tt.ptr, i64, tensor<32xi64> return } +// 
----- +// CHECK-LABEL: @print_scalar_and_tensor +// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: tensor<16xi32>, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: tensor<2x8xf32>, %[[ARG5:.*]]: tensor<16x!tt.ptr>, %[[ARG6:.*]]: !tt.ptr) +// CHECK: %[[ARG7:.*]] = builtin.unrealized_conversion_cast %[[ARG5]] : tensor<16x!tt.ptr> to tensor<16xi64> +// CHECK: %[[ARG8:.*]] = builtin.unrealized_conversion_cast %[[ARG6]] : !tt.ptr to i64 +// CHECK: %[[ARG9:.*]] = tt.get_program_id x : i32 +// CHECK: %[[ARG10:.*]] = tt.get_program_id y : i32 +// CHECK: %[[ARG11:.*]] = tt.get_program_id z : i32 +// CHECK: aux.scalar.print(%[[ARG9]] : i32) {format = "pid ("} +// CHECK: aux.scalar.print(%[[ARG10]] : i32) {format = ", "} +// CHECK: aux.scalar.print(%[[ARG11]] : i32) {format = ", "} +// CHECK: aux.scalar.print {format = ") "} +// CHECK: aux.scalar.print(%[[ARG0]] : i32) {format = "0x%08x"} +// CHECK: %[[ARG12:.*]] = aux.print(%[[ARG1]] : tensor<16xi32>) {format = "0x%08x"} -> (tensor<16xi32>) +// CHECK: %[[ARG13:.*]] = tt.get_program_id x : i32 +// CHECK: %[[ARG14:.*]] = tt.get_program_id y : i32 +// CHECK: %[[ARG15:.*]] = tt.get_program_id z : i32 +// CHECK: aux.scalar.print(%[[ARG13]] : i32) {format = "pid ("} +// CHECK: aux.scalar.print(%[[ARG14]] : i32) {format = ", "} +// CHECK: aux.scalar.print(%[[ARG15]] : i32) {format = ", "} +// CHECK: aux.scalar.print {format = ") "} +// CHECK: aux.scalar.print(%[[ARG0]] : i32) {format = "arg0, arg1, arg20x%08x"} +// CHECK: %[[ARG16:.*]] = aux.print(%[[ARG12]] : tensor<16xi32>) {format = "arg0, arg1, arg20x%08x"} -> (tensor<16xi32>) +// CHECK: aux.scalar.print(%[[ARG2]] : f32) {format = "arg0, arg1, arg20x%08x"} +// CHECK: %[[ARG17:.*]] = tt.get_program_id x : i32 +// CHECK: %[[ARG18:.*]] = tt.get_program_id y : i32 +// CHECK: %[[ARG19:.*]] = tt.get_program_id z : i32 +// CHECK: aux.scalar.print(%[[ARG17]] : i32) {format = "pid ("} +// CHECK: aux.scalar.print(%[[ARG18]] : i32) {format = ", "} +// CHECK: aux.scalar.print(%[[ARG19]] : i32) {format = ", "} +// CHECK: aux.scalar.print {format = ") "} +// CHECK: aux.scalar.print(%[[ARG0]] : i32) {format = "0x%08x"} +// CHECK: aux.scalar.print(%[[ARG2]] : f32) {format = "0x%08x"} +// CHECK: %[[ARG20:.*]] = aux.print(%[[ARG16]] : tensor<16xi32>) {format = "0x%08x"} -> (tensor<16xi32>) +// CHECK: %[[ARG21:.*]] = aux.print(%[[ARG4]] : tensor<2x8xf32>) {format = "0x%08x"} -> (tensor<2x8xf32>) +// CHECK: aux.scalar.print(%[[ARG3]] : i32) {format = "0x%08x"} +// CHECK: %[[ARG22:.*]] = tt.get_program_id x : i32 +// CHECK: %[[ARG23:.*]] = tt.get_program_id y : i32 +// CHECK: %[[ARG24:.*]] = tt.get_program_id z : i32 +// CHECK: aux.scalar.print(%[[ARG22]] : i32) {format = "pid ("} +// CHECK: aux.scalar.print(%[[ARG23]] : i32) {format = ", "} +// CHECK: aux.scalar.print(%[[ARG24]] : i32) {format = ", "} +// CHECK: aux.scalar.print {format = ") "} +// CHECK: %[[ARG25:.*]] = aux.print(%[[ARG20]] : tensor<16xi32>) {format = "arg1, arg5, arg0, arg6"} -> (tensor<16xi32>) +// CHECK: %[[ARG26:.*]] = aux.print(%[[ARG7]] : tensor<16xi64>) {format = "arg1, arg5, arg0, arg6%p"} -> (tensor<16xi64>) +// CHECK: aux.scalar.print(%[[ARG0]] : i32) {format = "arg1, arg5, arg0, arg6"} +// CHECK: aux.scalar.print(%[[ARG8]] : i64) {format = "arg1, arg5, arg0, arg6%p"} +func.func @print_scalar_and_tensor(%arg0: i32, %arg1: tensor<16xi32>, %arg2: f32, %arg3: i32, %arg4: tensor<2x8xf32>, %arg5: tensor<16x!tt.ptr>, %arg6: !tt.ptr) { + tt.print "" { hex = true } : %arg0, %arg1 : i32, tensor<16xi32> + tt.print "arg0, arg1, arg2" { hex = true } : 
%arg0, %arg1, %arg2 : i32, tensor<16xi32>, f32 + tt.print "" { hex = true } : %arg0, %arg2, %arg1, %arg4, %arg3 : i32, f32, tensor<16xi32>, tensor<2x8xf32>, i32 + tt.print "arg1, arg5, arg0, arg6" { hex = false } : %arg1, %arg5, %arg0, %arg6 : tensor<16xi32>, tensor<16x!tt.ptr>, i32, !tt.ptr + return +} + // ----- // CHECK-LABEL: @scan_add_2d_i32( // CHECK-SAME: %[[INPUT:.*]]: tensor<1x2048xi32>) -> tensor<1x2048xi32> { @@ -1370,10 +1592,8 @@ tt.func @scan_add_2d_i32(%arg0: tensor<1x2048xi32>) -> tensor<1x2048xi32> { // CHECK: %[[SCAN_INPUT:.*]] = tensor.extract_slice %[[INPUT]][0, 1] [1, 2047] [1, 1] : tensor<1x2048xi32> to tensor<1x2047xi32> // CHECK-NEXT: %[[SCAN_OUTPUT:.*]] = tensor.empty() : tensor<1x2047xi32> // CHECK-NEXT: %[[INIT:.*]] = tensor.extract_slice %[[INPUT]][0, 0] [1, 1] [1, 1] : tensor<1x2048xi32> to tensor<1x1xi32> - // CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index - // CHECK-NEXT: %[[SHAPE:.*]] = tensor.from_elements %[[C1]] : tensor<1xindex> - // CHECK-NEXT: %[[SCAN_INIT:.*]] = tensor.reshape %[[INIT]](%[[SHAPE]]) : (tensor<1x1xi32>, tensor<1xindex>) -> tensor<1xi32> - // CHECK-NEXT: %[[SCAN:.*]]:2 = linalg_ext.scan ins(%[[SCAN_INPUT]] : tensor<1x2047xi32>) outs(%[[SCAN_OUTPUT]], %[[SCAN_INIT]] : tensor<1x2047xi32>, tensor<1xi32>) dimensions = [1] { + // CHECK-NEXT: %[[SCAN_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[\[0, 1]]}} : tensor<1x1xi32> into tensor<1xi32> + // CHECK-NEXT: %[[SCAN:.*]]:2 = linalg_ext.scan ins(%[[SCAN_INPUT]] : tensor<1x2047xi32>) outs(%[[SCAN_OUTPUT]], %[[SCAN_INIT]] : tensor<1x2047xi32>, tensor<1xi32>) dimensions = [1] reverse = false { // CHECK-NEXT: ^bb0(%[[IN:.*]]: i32, %[[OUTPUT:.*]]: i32, %[[INIT:.*]]: i32): // CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[IN]], %[[INIT]] : i32 // CHECK-NEXT: linalg_ext.yield %[[ADD]], %[[ADD]] : i32, i32 @@ -1396,10 +1616,8 @@ tt.func @scan_min_2d_f16(%arg0: tensor<1x2048xf16>) -> tensor<1x2048xf16> { // CHECK: %[[SCAN_INPUT:.*]] = tensor.extract_slice %[[INPUT]][0, 1] [1, 2047] [1, 1] : tensor<1x2048xf16> to tensor<1x2047xf16> // CHECK-NEXT: %[[SCAN_OUTPUT:.*]] = tensor.empty() : tensor<1x2047xf16> // CHECK-NEXT: %[[INIT:.*]] = tensor.extract_slice %[[INPUT]][0, 0] [1, 1] [1, 1] : tensor<1x2048xf16> to tensor<1x1xf16> - // CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index - // CHECK-NEXT: %[[SHAPE:.*]] = tensor.from_elements %[[C1]] : tensor<1xindex> - // CHECK-NEXT: %[[SCAN_INIT:.*]] = tensor.reshape %[[INIT]](%[[SHAPE]]) : (tensor<1x1xf16>, tensor<1xindex>) -> tensor<1xf16> - // CHECK-NEXT: %[[SCAN:.*]]:2 = linalg_ext.scan ins(%[[SCAN_INPUT]] : tensor<1x2047xf16>) outs(%[[SCAN_OUTPUT]], %[[SCAN_INIT]] : tensor<1x2047xf16>, tensor<1xf16>) dimensions = [1] { + // CHECK-NEXT: %[[SCAN_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[\[0, 1]]}} : tensor<1x1xf16> into tensor<1xf16> + // CHECK-NEXT: %[[SCAN:.*]]:2 = linalg_ext.scan ins(%[[SCAN_INPUT]] : tensor<1x2047xf16>) outs(%[[SCAN_OUTPUT]], %[[SCAN_INIT]] : tensor<1x2047xf16>, tensor<1xf16>) dimensions = [1] reverse = false { // CHECK-NEXT: ^bb0(%[[IN:.*]]: f16, %[[OUTPUT:.*]]: f16, %[[INIT:.*]]: f16): // CHECK-NEXT: %[[MIN:.*]] = arith.minimumf %[[IN]], %[[INIT]] : f16 // CHECK-NEXT: linalg_ext.yield %[[MIN]], %[[MIN]] : f16, f16 @@ -1435,9 +1653,8 @@ tt.func @scan_sub_1d_f32(%arg0: tensor<2048xf32>) -> tensor<2048xf32> { // CHECK: %[[SCAN_INPUT:.*]] = tensor.extract_slice %[[INPUT]][1] [2047] [1] : tensor<2048xf32> to tensor<2047xf32> // CHECK-NEXT: %[[SCAN_OUTPUT:.*]] = tensor.empty() : tensor<2047xf32> // CHECK-NEXT: %[[INIT:.*]] = 
tensor.extract_slice %[[INPUT]][0] [1] [1] : tensor<2048xf32> to tensor<1xf32> - // CHECK-NEXT: %[[SHAPE:.*]] = tensor.from_elements : tensor<0xindex> - // CHECK-NEXT: %[[SCAN_INIT:.*]] = tensor.reshape %[[INIT]](%[[SHAPE]]) : (tensor<1xf32>, tensor<0xindex>) -> tensor - // CHECK-NEXT: %[[SCAN:.*]]:2 = linalg_ext.scan ins(%[[SCAN_INPUT]] : tensor<2047xf32>) outs(%[[SCAN_OUTPUT]], %[[SCAN_INIT]] : tensor<2047xf32>, tensor) dimensions = [0] { + // CHECK-NEXT: %[[SCAN_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[]}} : tensor<1xf32> into tensor + // CHECK-NEXT: %[[SCAN:.*]]:2 = linalg_ext.scan ins(%[[SCAN_INPUT]] : tensor<2047xf32>) outs(%[[SCAN_OUTPUT]], %[[SCAN_INIT]] : tensor<2047xf32>, tensor) dimensions = [0] reverse = false { // CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUTPUT:.*]]: f32, %[[INIT:.*]]: f32): // CHECK-NEXT: %[[SUB:.*]] = arith.subf %[[IN]], %[[INIT]] : f32 // CHECK-NEXT: linalg_ext.yield %[[SUB]], %[[SUB]] : f32, f32 @@ -1452,3 +1669,315 @@ tt.func @scan_sub_1d_f32(%arg0: tensor<2048xf32>) -> tensor<2048xf32> { }) {axis = 0 : i32, reverse = false} : (tensor<2048xf32>) -> tensor<2048xf32> tt.return %0 : tensor<2048xf32> } + +// ----- +// CHECK-LABEL: @scan_add_2d_i32_reverse( +// CHECK-SAME: %[[INPUT:.*]]: tensor<1x2048xi32>) -> tensor<1x2048xi32> { +tt.func @scan_add_2d_i32_reverse(%arg0: tensor<1x2048xi32>) -> tensor<1x2048xi32> { + // CHECK: %[[SCAN_INPUT:.*]] = tensor.extract_slice %[[INPUT]][0, 0] [1, 2047] [1, 1] : tensor<1x2048xi32> to tensor<1x2047xi32> + // CHECK-NEXT: %[[SCAN_OUTPUT:.*]] = tensor.empty() : tensor<1x2047xi32> + // CHECK-NEXT: %[[INIT:.*]] = tensor.extract_slice %[[INPUT]][0, 2047] [1, 1] [1, 1] : tensor<1x2048xi32> to tensor<1x1xi32> + // CHECK-NEXT: %[[SCAN_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[\[0, 1]]}} : tensor<1x1xi32> into tensor<1xi32> + // CHECK-NEXT: %[[SCAN:.*]]:2 = linalg_ext.scan ins(%[[SCAN_INPUT]] : tensor<1x2047xi32>) outs(%[[SCAN_OUTPUT]], %[[SCAN_INIT]] : tensor<1x2047xi32>, tensor<1xi32>) dimensions = [1] reverse = true { + // CHECK-NEXT: ^bb0(%[[IN:.*]]: i32, %[[OUTPUT:.*]]: i32, %[[INIT:.*]]: i32): + // CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[IN]], %[[INIT]] : i32 + // CHECK-NEXT: linalg_ext.yield %[[ADD]], %[[ADD]] : i32, i32 + // CHECK-NEXT: } -> tensor<1x2047xi32>, tensor<1xi32> + // CHECK-NEXT: %[[INERT_SLICE:.*]] = tensor.insert_slice %[[SCAN]]#0 into %[[INPUT]][0, 0] [1, 2047] [1, 1] : tensor<1x2047xi32> into tensor<1x2048xi32> + // CHECK-NEXT: return %[[INERT_SLICE]] : tensor<1x2048xi32> + + %0 = "tt.scan" (%arg0) ({ + ^bb0(%arg1: i32, %arg2: i32): + %1 = arith.addi %arg1, %arg2 : i32 + tt.scan.return %1 : i32 + }) {axis = 1 : i32, reverse = true} : (tensor<1x2048xi32>) -> tensor<1x2048xi32> + tt.return %0 : tensor<1x2048xi32> +} + +// ----- +// CHECK-LABEL: @scan_min_2d_f16_reverse( +// CHECK-SAME: %[[INPUT:.*]]: tensor<1x2048xf16>) -> tensor<1x2048xf16> { +tt.func @scan_min_2d_f16_reverse(%arg0: tensor<1x2048xf16>) -> tensor<1x2048xf16> { + // CHECK: %[[SCAN_INPUT:.*]] = tensor.extract_slice %[[INPUT]][0, 0] [1, 2047] [1, 1] : tensor<1x2048xf16> to tensor<1x2047xf16> + // CHECK-NEXT: %[[SCAN_OUTPUT:.*]] = tensor.empty() : tensor<1x2047xf16> + // CHECK-NEXT: %[[INIT:.*]] = tensor.extract_slice %[[INPUT]][0, 2047] [1, 1] [1, 1] : tensor<1x2048xf16> to tensor<1x1xf16> + // CHECK-NEXT: %[[SCAN_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[\[0, 1]]}} : tensor<1x1xf16> into tensor<1xf16> + // CHECK-NEXT: %[[SCAN:.*]]:2 = linalg_ext.scan ins(%[[SCAN_INPUT]] : tensor<1x2047xf16>) outs(%[[SCAN_OUTPUT]], %[[SCAN_INIT]] 
: tensor<1x2047xf16>, tensor<1xf16>) dimensions = [1] reverse = true { + // CHECK-NEXT: ^bb0(%[[IN:.*]]: f16, %[[OUTPUT:.*]]: f16, %[[INIT:.*]]: f16): + // CHECK-NEXT: %[[MIN:.*]] = arith.minimumf %[[IN]], %[[INIT]] : f16 + // CHECK-NEXT: linalg_ext.yield %[[MIN]], %[[MIN]] : f16, f16 + // CHECK-NEXT: } -> tensor<1x2047xf16>, tensor<1xf16> + // CHECK-NEXT: %[[INERT_SLICE:.*]] = tensor.insert_slice %[[SCAN]]#0 into %[[INPUT]][0, 0] [1, 2047] [1, 1] : tensor<1x2047xf16> into tensor<1x2048xf16> + // CHECK-NEXT: return %[[INERT_SLICE]] : tensor<1x2048xf16> + + %0 = "tt.scan" (%arg0) ({ + ^bb0(%arg1: f16, %arg2: f16): + %1 = arith.minimumf %arg1, %arg2 : f16 + tt.scan.return %1 : f16 + }) {axis = 1 : i32, reverse = true} : (tensor<1x2048xf16>) -> tensor<1x2048xf16> + tt.return %0 : tensor<1x2048xf16> +} + +// ----- +// CHECK-LABEL: @scan_sub_1d_f32_reverse( +// CHECK-SAME: %[[INPUT:.*]]: tensor<2048xf32>) -> tensor<2048xf32> { +tt.func @scan_sub_1d_f32_reverse(%arg0: tensor<2048xf32>) -> tensor<2048xf32> { + // CHECK: %[[SCAN_INPUT:.*]] = tensor.extract_slice %[[INPUT]][0] [2047] [1] : tensor<2048xf32> to tensor<2047xf32> + // CHECK-NEXT: %[[SCAN_OUTPUT:.*]] = tensor.empty() : tensor<2047xf32> + // CHECK-NEXT: %[[INIT:.*]] = tensor.extract_slice %[[INPUT]][2047] [1] [1] : tensor<2048xf32> to tensor<1xf32> + // CHECK-NEXT: %[[SCAN_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[]}} : tensor<1xf32> into tensor + // CHECK-NEXT: %[[SCAN:.*]]:2 = linalg_ext.scan ins(%[[SCAN_INPUT]] : tensor<2047xf32>) outs(%[[SCAN_OUTPUT]], %[[SCAN_INIT]] : tensor<2047xf32>, tensor) dimensions = [0] reverse = true { + // CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUTPUT:.*]]: f32, %[[INIT:.*]]: f32): + // CHECK-NEXT: %[[SUB:.*]] = arith.subf %[[IN]], %[[INIT]] : f32 + // CHECK-NEXT: linalg_ext.yield %[[SUB]], %[[SUB]] : f32, f32 + // CHECK-NEXT: } -> tensor<2047xf32>, tensor + // CHECK-NEXT: %[[INERT_SLICE:.*]] = tensor.insert_slice %[[SCAN]]#0 into %[[INPUT]][0] [2047] [1] : tensor<2047xf32> into tensor<2048xf32> + // CHECK-NEXT: return %[[INERT_SLICE]] : tensor<2048xf32> + + %0 = "tt.scan" (%arg0) ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.subf %arg1, %arg2 : f32 + tt.scan.return %1 : f32 + }) {axis = 0 : i32, reverse = true} : (tensor<2048xf32>) -> tensor<2048xf32> + tt.return %0 : tensor<2048xf32> +} + +// ----- +// CHECK-LABEL: @scan_add_1d_size1_f32_reverse( +// CHECK-SAME: %[[INPUT:.*]]: tensor<1xf32> +tt.func @scan_add_1d_size1_f32_reverse(%arg0: tensor<1xf32>) -> tensor<1xf32> { + // CHECK: return %[[INPUT]] : tensor<1xf32> + %0 = "tt.scan" (%arg0) ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.addf %arg1, %arg2 : f32 + tt.scan.return %1 : f32 + }) {axis = 0 : i32, reverse = true} : (tensor<1xf32>) -> tensor<1xf32> + tt.return %0 : tensor<1xf32> +} + +// ----- +tt.func @tt_mulhiui_scalar_i32(%arg0: i32, %arg1: i32) { + // CHECK: math_ext.mulhiui + %0 = tt.mulhiui %arg0, %arg1 : i32 + tt.return +} + +// ----- +tt.func @tt_mulhiui_vector_i32(%arg0: tensor<16x16xi32>, %arg1: tensor<16x16xi32>) { + // CHECK: math_ext.mulhiui + %0 = tt.mulhiui %arg0, %arg1 : tensor<16x16xi32> + tt.return +} + +// ----- +// CHECK-LABEL: @join_int8 +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x8xi8>, %[[ARG1:.*]]: tensor<2x8xi8> +tt.func @join_int8(%arg0: tensor<2x8xi8>, %arg1: tensor<2x8xi8>) { + // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<2x8x2xi8> + // CHECK: %[[INSET1:.*]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0, 0] [2, 8, 1] [1, 1, 1] : tensor<2x8xi8> into tensor<2x8x2xi8> + // CHECK: %[[INSET2:.*]] = 
tensor.insert_slice %[[ARG1]] into %[[INSET1]][0, 0, 1] [2, 8, 1] [1, 1, 1] : tensor<2x8xi8> into tensor<2x8x2xi8> + %0 = tt.join %arg0, %arg1 : tensor<2x8xi8> -> tensor<2x8x2xi8> + tt.return +} + +// ----- +// CHECK-LABEL: @join_float32 +// CHECK-SAME: %[[ARG0:.*]]: tensor<4x2x8xf32>, %[[ARG1:.*]]: tensor<4x2x8xf32> +tt.func @join_float32(%arg0: tensor<4x2x8xf32>, %arg1: tensor<4x2x8xf32>) { + // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<4x2x8x2xf32> + // CHECK: %[[INSET1:.*]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0, 0, 0] [4, 2, 8, 1] [1, 1, 1, 1] : tensor<4x2x8xf32> into tensor<4x2x8x2xf32> + // CHECK: %[[INSET2:.*]] = tensor.insert_slice %[[ARG1]] into %[[INSET1]][0, 0, 0, 1] [4, 2, 8, 1] [1, 1, 1, 1] : tensor<4x2x8xf32> into tensor<4x2x8x2xf32> + %0 = tt.join %arg0, %arg1 : tensor<4x2x8xf32> -> tensor<4x2x8x2xf32> + tt.return +} + +// ----- +// CHECK-LABEL: @join_scalar +// CHECK-SAME: %[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor +tt.func @join_scalar(%arg0: tensor, %arg1: tensor) { + // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<2xf32> + // CHECK: %[[INSET1:.*]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0] [1] [1] : tensor into tensor<2xf32> + // CHECK: %[[INSET2:.*]] = tensor.insert_slice %[[ARG1]] into %[[INSET1]][1] [1] [1] : tensor into tensor<2xf32> + %0 = tt.join %arg0, %arg1 : tensor -> tensor<2xf32> + tt.return +} + +// ----- +// CHECK-LABEL: @split_int8 +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x8x2xi8> +tt.func @split_int8(%arg0: tensor<2x8x2xi8>) { + // CHECK: %[[SLICE1:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [2, 8, 1] [1, 1, 1] : tensor<2x8x2xi8> to tensor<2x8xi8> + // CHECK: %[[SLICE2:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 1] [2, 8, 1] [1, 1, 1] : tensor<2x8x2xi8> to tensor<2x8xi8> + %0, %1 = tt.split %arg0 : tensor<2x8x2xi8> -> tensor<2x8xi8> + tt.return +} + +// ----- +// CHECK-LABEL: @split_float32 +// CHECK-SAME: %[[ARG0:.*]]: tensor<4x2x8x2xf32> +tt.func @split_float32(%arg0: tensor<4x2x8x2xf32>) { + // CHECK: %[[SLICE1:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0, 0] [4, 2, 8, 1] [1, 1, 1, 1] : tensor<4x2x8x2xf32> to tensor<4x2x8xf32> + // CHECK: %[[SLICE2:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0, 1] [4, 2, 8, 1] [1, 1, 1, 1] : tensor<4x2x8x2xf32> to tensor<4x2x8xf32> + %0, %1 = tt.split %arg0 : tensor<4x2x8x2xf32> -> tensor<4x2x8xf32> + tt.return +} + +// ----- +// CHECK-LABEL: @split_one_dim +// CHECK-SAME: %[[ARG0:.*]]: tensor<2xf32> +tt.func @split_one_dim(%arg0: tensor<2xf32>) { + // CHECK: %[[SLICE1:.*]] = tensor.extract_slice %[[ARG0]][0] [1] [1] : tensor<2xf32> to tensor + // CHECK: %[[SLICE2:.*]] = tensor.extract_slice %[[ARG0]][1] [1] [1] : tensor<2xf32> to tensor + %0, %1 = tt.split %arg0 : tensor<2xf32> -> tensor + tt.return +} + +// ----- +// CHECK-LABEL: @tt_precise_sqrt_vector_f16 +tt.func @tt_precise_sqrt_vector_f16(%arg0: tensor<128xf16>) { + // CHECK: tensor.empty + // CHECK: linalg.map { math.sqrt + %0 = tt.precise_sqrt %arg0 : tensor<128xf16> + tt.return +} + +// ----- +// CHECK-LABEL: @tt_precise_sqrt_vector_f32 +tt.func @tt_precise_sqrt_vector_f32(%arg0: tensor<128xf32>) { + // CHECK: tensor.empty + // CHECK: linalg.map { math.sqrt + %0 = tt.precise_sqrt %arg0 : tensor<128xf32> + tt.return +} + +// ----- +// CHECK-LABEL: @tt_precise_divf_vector_f16 +tt.func @tt_precise_divf_vector_f16(%arg0: tensor<128xf16>, %arg1: tensor<128xf16>) { + // CHECK: tensor.empty + // CHECK: linalg.map { arith.divf } + %0 = tt.precise_divf %arg0, %arg1 : tensor<128xf16> + tt.return +} + +// ----- +// CHECK-LABEL: 
@tt_precise_divf_vector_f32 +tt.func @tt_precise_divf_vector_f32(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>) { + // CHECK: tensor.empty + // CHECK: linalg.map { arith.divf } + %0 = tt.precise_divf %arg0, %arg1 : tensor<128xf32> + tt.return +} + +// ----- +// CHECK-LABEL: @clampf_propagateNan_all_f32( +// CHECK-SAME: %[[ARG0:.*]]: tensor<32xf32>, %[[ARG1:.*]]: tensor<32xf32>, %[[ARG2:.*]]: tensor<32xf32> +tt.func @clampf_propagateNan_all_f32(%x: tensor<32xf32>, %min: tensor<32xf32>, %max: tensor<32xf32>) -> tensor<32xf32> { + // CHECK: %[[MAPPED:.*]] = linalg.map { arith.maximumf } ins(%[[ARG0]], %[[ARG1]] : tensor<32xf32>, tensor<32xf32>) + // CHECK: linalg.map { arith.minimumf } ins(%[[MAPPED]], %[[ARG2]] : tensor<32xf32>, tensor<32xf32>) + %0 = tt.clampf %x, %min, %max, propagateNan = all : tensor<32xf32> + tt.return %0 : tensor<32xf32> +} + +// ----- +// CHECK-LABEL: @clampf_propagateNan_none_f32( +// CHECK-SAME: %[[ARG0:.*]]: tensor<32xf32>, %[[ARG1:.*]]: tensor<32xf32>, %[[ARG2:.*]]: tensor<32xf32> +tt.func @clampf_propagateNan_none_f32(%x: tensor<32xf32>, %min: tensor<32xf32>, %max: tensor<32xf32>) -> tensor<32xf32> { + // CHECK: %[[MAPPED]] = linalg.map { arith.maxnumf } ins(%[[ARG0]], %[[ARG1]] : tensor<32xf32>, tensor<32xf32>) + // CHECK: linalg.map { arith.minnumf } ins(%[[MAPPED]], %[[ARG2]] : tensor<32xf32>, tensor<32xf32>) + %0 = tt.clampf %x, %min, %max, propagateNan = none : tensor<32xf32> + tt.return %0 : tensor<32xf32> +} + +// ----- +// CHECK-LABEL: @clampf_propagateNan_all_f16( +// CHECK-SAME: %[[ARG0:.*]]: tensor<32xf16>, %[[ARG1:.*]]: tensor<32xf16>, %[[ARG2:.*]]: tensor<32xf16> +tt.func @clampf_propagateNan_all_f16(%x: tensor<32xf16>, %min: tensor<32xf16>, %max: tensor<32xf16>) -> tensor<32xf16> { + // CHECK: %[[MAPPED]] = linalg.map { arith.maximumf } ins(%[[ARG0]], %[[ARG1]] : tensor<32xf16>, tensor<32xf16>) + // CHECK: linalg.map { arith.minimumf } ins(%[[MAPPED]], %[[ARG2]] : tensor<32xf16>, tensor<32xf16>) + %0 = tt.clampf %x, %min, %max, propagateNan = all : tensor<32xf16> + tt.return %0 : tensor<32xf16> +} + +// ----- +// CHECK-LABEL: @clampf_propagateNan_none_f16( +// CHECK-SAME: %[[ARG0:.*]]: tensor<32xf16>, %[[ARG1:.*]]: tensor<32xf16>, %[[ARG2:.*]]: tensor<32xf16> +tt.func @clampf_propagateNan_none_f16(%x: tensor<32xf16>, %min: tensor<32xf16>, %max: tensor<32xf16>) -> tensor<32xf16> { + // CHECK: %[[MAPPED]] = linalg.map { arith.maxnumf } ins(%[[ARG0]], %[[ARG1]] : tensor<32xf16>, tensor<32xf16>) + // CHECK: linalg.map { arith.minnumf } ins(%[[MAPPED]], %[[ARG2]] : tensor<32xf16>, tensor<32xf16>) + %0 = tt.clampf %x, %min, %max, propagateNan = none : tensor<32xf16> + tt.return %0 : tensor<32xf16> +} + +// ----- +// CHECK-LABEL: @histogram_i32 +// CHECK-SAME: %[[ARG0:.*]]: tensor<8xi32>) +// CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32 +// CHECK: %[[C1_I32:.*]] = arith.constant 1 : i32 +// CHECK: %[[C2_I32:.*]] = arith.constant 2 : i32 +// CHECK: %[[ARG1:.*]] = arith.subi %[[C2_I32]], %[[C1_I32]] : i32 +// CHECK: %[[ARG2:.*]] = tensor.empty() : tensor<2xi32> +// CHECK: %[[C0_I32_0:.*]] = arith.constant 0 : i32 +// CHECK: %[[ARG3:.*]] = linalg.fill ins(%[[C0_I32_0]] : i32) outs(%[[ARG2]] : tensor<2xi32>) -> tensor<2xi32> +// CHECK: %[[C8:.*]] = arith.constant 8 : index +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[ARG4:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C8]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG3]]) -> (tensor<2xi32>) { +// CHECK: %[[ARG7:.*]] = tensor.extract 
%[[ARG0]][%[[ARG5]]] : tensor<8xi32> +// CHECK: %[[ARG8:.*]] = arith.cmpi sle, %[[C0_I32]], %[[ARG7]] : i32 +// CHECK: %[[ARG9:.*]] = arith.cmpi sge, %[[ARG1]], %[[ARG7]] : i32 +// CHECK: %[[ARG10:.*]] = arith.andi %[[ARG8]], %[[ARG9]] : i1 +// CHECK: %[[ARG11:.*]] = scf.if %[[ARG10]] -> (tensor<2xi32>) { +// CHECK: %[[ARG12:.*]] = arith.subi %[[ARG7]], %[[C0_I32]] : i32 +// CHECK: %[[ARG13:.*]] = arith.index_cast %[[ARG12]] : i32 to index +// CHECK: %[[ARG14:.*]] = tensor.extract %[[ARG6]][%[[ARG13]]] : tensor<2xi32> +// CHECK: %[[C1_I32_2:.*]] = arith.constant 1 : i32 +// CHECK: %[[ARG15:.*]] = arith.addi %[[ARG14]], %[[C1_I32_2]] : i32 +// CHECK: %[[ARG16:.*]] = tensor.insert %[[ARG15]] into %[[ARG6]][%[[ARG13]]] : tensor<2xi32> +// CHECK: scf.yield %[[ARG16]] : tensor<2xi32> +// CHECK: } else { +// CHECK: scf.yield %[[ARG6]] : tensor<2xi32> +// CHECK: } +// CHECK: scf.yield %[[ARG11]] : tensor<2xi32> +// CHECK: } +// CHECK: return +tt.func @histogram_i32(%0: tensor<8xi32>) { + %1 = tt.histogram %0 : tensor<8xi32> -> tensor<2xi32> + tt.return +} + +// ----- +// CHECK-LABEL: @histogram_i64 +// CHECK-SAME: %[[ARG0:.*]]: tensor<128xi64>) +// CHECK: %[[C0_I64:.*]] = arith.constant 0 : i64 +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64 +// CHECK: %[[ARG1:.*]] = arith.subi %[[C32_I64]], %[[C1_I64]] : i64 +// CHECK: %[[ARG2:.*]] = tensor.empty() : tensor<32xi64> +// CHECK: %[[C0_I64_0:.*]] = arith.constant 0 : i64 +// CHECK: %[[ARG3:.*]] = linalg.fill ins(%[[C0_I64_0]] : i64) outs(%[[ARG2]] : tensor<32xi64>) -> tensor<32xi64> +// CHECK: %[[C128:.*]] = arith.constant 128 : index +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[ARG4:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C128]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG3]]) -> (tensor<32xi64>) { +// CHECK: %[[ARG7:.*]] = tensor.extract %[[ARG0]][%[[ARG5]]] : tensor<128xi64> +// CHECK: %[[ARG8:.*]] = arith.cmpi sle, %[[C0_I64]], %[[ARG7]] : i64 +// CHECK: %[[ARG9:.*]] = arith.cmpi sge, %[[ARG1]], %[[ARG7]] : i64 +// CHECK: %[[ARG10:.*]] = arith.andi %[[ARG8]], %[[ARG9]] : i1 +// CHECK: %[[ARG11:.*]] = scf.if %[[ARG10]] -> (tensor<32xi64>) { +// CHECK: %[[ARG12:.*]] = arith.subi %[[ARG7]], %[[C0_I64]] : i64 +// CHECK: %[[ARG13:.*]] = arith.index_cast %[[ARG12]] : i64 to index +// CHECK: %[[ARG14:.*]] = tensor.extract %[[ARG6]][%[[ARG13]]] : tensor<32xi64> +// CHECK: %[[C1_I64_2:.*]] = arith.constant 1 : i64 +// CHECK: %[[ARG15:.*]] = arith.addi %[[ARG14]], %[[C1_I64_2]] : i64 +// CHECK: %[[ARG16:.*]] = tensor.insert %[[ARG15]] into %[[ARG6]][%[[ARG13]]] : tensor<32xi64> +// CHECK: scf.yield %[[ARG16]] : tensor<32xi64> +// CHECK: } else { +// CHECK: scf.yield %[[ARG6]] : tensor<32xi64> +// CHECK: } +// CHECK: scf.yield %[[ARG11]] : tensor<32xi64> +// CHECK: } +// CHECK: return +tt.func @histogram_i64(%0: tensor<128xi64>) { + %1 = tt.histogram %0 : tensor<128xi64> -> tensor<32xi64> + tt.return +} diff --git a/test/Dialect/LinalgExt/invalid.mlir b/test/Dialect/LinalgExt/invalid.mlir index b2901a9..3945044 100644 --- a/test/Dialect/LinalgExt/invalid.mlir +++ b/test/Dialect/LinalgExt/invalid.mlir @@ -137,7 +137,6 @@ func.func @batch_conv_2d_nhwc_fhwc_invalid_dtype_in_strides(%input: tensor) -> tensor<2x64xi32> { %c0 = arith.constant 0 : i32 %c128 = arith.constant 128 : i32 @@ -147,7 +146,6 @@ func.func @make_range_output_rank_invalid(%arg0: tensor<2x64xi32>) -> tensor<2x6 } // ----- -// CHECK: linalg_ext.make_range func.func 
@make_range_start_end_invalid(%arg0: tensor<128xi32>) -> tensor<128xi32> { %c0 = arith.constant 0 : i32 %c128 = arith.constant 128 : i32 @@ -157,7 +155,6 @@ func.func @make_range_start_end_invalid(%arg0: tensor<128xi32>) -> tensor<128xi3 } // ----- -// CHECK: linalg_ext.make_range func.func @make_range_output_shape_mismatch(%arg0: tensor<129xi32>) -> tensor<129xi32> { %c0 = arith.constant 0 : i32 %c128 = arith.constant 128 : i32 @@ -167,7 +164,6 @@ func.func @make_range_output_shape_mismatch(%arg0: tensor<129xi32>) -> tensor<12 } // ----- -// CHECK: linalg_ext.make_range func.func @make_range_result_type_invalid(%arg0: tensor<128xf32>) -> tensor<128xf32> { %c2 = arith.constant 2 : i32 %c130 = arith.constant 130 : i32 @@ -200,8 +196,7 @@ func.func @scatter_extra_outputs( %init : tensor) -> (tensor, tensor) { // expected-error @+1 {{expected the number of tensor results (2) to be equal to the number of output tensors (1)}} %0, %1 = linalg_ext.scatter dimension_map = [0] - ranged_data(false) - overlap_window(false) + ranged_data(false) overlap_window(false) signed_indice(false) ins(%update, %indices : tensor, tensor) outs(%init : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -217,8 +212,7 @@ func.func @scatter_mistmatch_dim_map_entries( %init : tensor) -> tensor { // expected-error @+1 {{invalid number of dimension map entries}} %0 = linalg_ext.scatter dimension_map = [0, 1] - ranged_data(false) - overlap_window(false) + ranged_data(false) overlap_window(false) signed_indice(false) ins(%update, %indices : tensor, tensor) outs(%init : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -234,8 +228,7 @@ func.func @scatter_nd_batch_exceed_dim( %init : tensor<4x4xf32>) -> tensor<4x4xf32> { // expected-error @+1 {{indexed shape of update value dim#2 exceeds init value at dim#0 8 .vs. 
4}} %0 = linalg_ext.scatter dimension_map = [0, 1] - ranged_data(false) - overlap_window(false) + ranged_data(false) overlap_window(false) signed_indice(false) ins(%update, %indices : tensor<1x1x8x8xf32>, tensor<1x1x2xi32>) outs(%init : tensor<4x4xf32>) { ^bb0(%arg1: f32, %arg2: f32): @@ -251,8 +244,7 @@ func.func @scatter_duplicate_dim_map_entries( %init : tensor) -> tensor { // expected-error @+1 {{dimension map is invalid}} %0 = linalg_ext.scatter dimension_map = [1, 1] - ranged_data(false) - overlap_window(false) + ranged_data(false) overlap_window(false) signed_indice(false) ins(%update, %indices : tensor, tensor) outs(%init : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -268,8 +260,7 @@ func.func @scatter_invalid_dim_map_entries( %init : tensor) -> tensor { // expected-error @+1 {{dimension map is invalid}} %0 = linalg_ext.scatter dimension_map = [2] - ranged_data(false) - overlap_window(false) + ranged_data(false) overlap_window(false) signed_indice(false) ins(%update, %indices : tensor, tensor) outs(%init : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -285,8 +276,7 @@ func.func @scatter_output_type_mismatch( %init : tensor) -> tensor<4x?xf32> { // expected-error @+1 {{expected type of operand #2 ('tensor') to match type of corresponding result ('tensor<4x?xf32>')}} %0 = linalg_ext.scatter dimension_map = [0] - ranged_data(false) - overlap_window(false) + ranged_data(false) overlap_window(false) signed_indice(false) ins(%update, %indices : tensor, tensor) outs(%init : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -302,8 +292,7 @@ func.func @scatter_dim_mismatch( %init : tensor) -> tensor { // expected-error @+1 {{mismatch in shape of indices and update value at batch dim}} %0 = linalg_ext.scatter dimension_map = [0] - ranged_data(false) - overlap_window(false) + ranged_data(false) overlap_window(false) signed_indice(false) ins(%update, %indices : tensor, tensor<48x1xi32>) outs(%init : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -319,8 +308,7 @@ func.func @scatter_dim_mismatch( %init : tensor) -> tensor { // expected-error @+1 {{mismatch in shape of indices and update value at batch dim}} %0 = linalg_ext.scatter dimension_map = [0] - ranged_data(false) - overlap_window(false) + ranged_data(false) overlap_window(false) signed_indice(false) ins(%update, %indices : tensor<64x?x1xf32>, tensor<48x1xi32>) outs(%init : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -336,8 +324,7 @@ func.func @scatter_dim_mismatch( %init : tensor) -> tensor { // expected-error @+1 {{op update value rank mismatch the rank of the init value}} %0 = linalg_ext.scatter dimension_map = [0] - ranged_data(false) - overlap_window(false) + ranged_data(false) overlap_window(false) signed_indice(false) ins(%update, %indices : tensor, tensor) outs(%init : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -353,8 +340,7 @@ func.func @scatter_dim_mismatch( %init : tensor) -> tensor { // expected-error @+1 {{op indexed shape of update value dim#2 exceeds init value at dim#1}} %0 = linalg_ext.scatter dimension_map = [0] - ranged_data(false) - overlap_window(false) + ranged_data(false) overlap_window(false) signed_indice(false) ins(%update, %indices : tensor, tensor) outs(%init : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -370,8 +356,7 @@ func.func @scatter_region_type_mismatch( %init : tensor) -> tensor { // expected-error @+1 {{expected region to have scalar argument of integer or float types}} %0 = linalg_ext.scatter dimension_map = [0] - ranged_data(false) - overlap_window(false) + ranged_data(false) overlap_window(false) 
signed_indice(false) ins(%update, %indices : tensor, tensor) outs(%init : tensor) { ^bb0(%arg1: index, %arg2: index): @@ -388,8 +373,7 @@ func.func @scatter_region_type_mismatch( %init : tensor) -> tensor { // expected-error @+1 {{mismatch in argument 0 of region 'i64' and element type of update value 'i32'}} %0 = linalg_ext.scatter dimension_map = [0] - ranged_data(false) - overlap_window(false) + ranged_data(false) overlap_window(false) signed_indice(false) ins(%update, %indices : tensor, tensor) outs(%init : tensor) { ^bb0(%arg1: i64, %arg2: i32): @@ -406,8 +390,7 @@ func.func @scatter_region_type_mismatch( %init : tensor) -> tensor { // expected-error @+1 {{mismatch in argument 1 of region 'i64' and element type of init value 'i32'}} %0 = linalg_ext.scatter dimension_map = [0] - ranged_data(false) - overlap_window(false) + ranged_data(false) overlap_window(false) signed_indice(false) ins(%update, %indices : tensor, tensor) outs(%init : tensor) { ^bb0(%arg1: i32, %arg2: i64): @@ -425,8 +408,7 @@ func.func @scatter_region_type_mismatch( %init : tensor) -> tensor { // expected-error @+1 {{mismatch in region argument types 'i32' and 'i64'}} %0 = linalg_ext.scatter dimension_map = [0] - ranged_data(false) - overlap_window(false) + ranged_data(false) overlap_window(false) signed_indice(false) ins(%update, %indices : tensor, tensor) outs(%init : tensor) { ^bb0(%arg1: i32, %arg2: i64): @@ -443,8 +425,7 @@ func.func @scatter_region_type_mismatch( %init : tensor) -> tensor { // expected-error @+1 {{expected region to have two arguments}} %0 = linalg_ext.scatter dimension_map = [0] - ranged_data(false) - overlap_window(false) + ranged_data(false) overlap_window(false) signed_indice(false) ins(%update, %indices : tensor, tensor) outs(%init : tensor) { ^bb0(%arg1: i64, %arg2: i64, %arg3 : i64): @@ -460,8 +441,7 @@ func.func @scatter_yield_mismatch( %update : tensor, %indices : tensor, %init : tensor) -> tensor { %0 = linalg_ext.scatter dimension_map = [0] - ranged_data(false) - overlap_window(false) + ranged_data(false) overlap_window(false) signed_indice(false) ins(%update, %indices : tensor, tensor) outs(%init : tensor) { ^bb0(%arg1: i64, %arg2: i64): @@ -478,8 +458,7 @@ func.func @scatter_yield_mismatch( %update : tensor, %indices : tensor, %init : tensor) -> tensor { %0 = linalg_ext.scatter dimension_map = [0] - ranged_data(false) - overlap_window(false) + ranged_data(false) overlap_window(false) signed_indice(false) ins(%update, %indices : tensor, tensor) outs(%init : tensor) { ^bb0(%arg1: i64, %arg2: i64): @@ -497,8 +476,7 @@ func.func @scatter_index_depth_dynamic( %init : tensor) -> tensor { // expected-error @+1 {{expected index depth is static}} %0 = linalg_ext.scatter dimension_map = [0] - ranged_data(false) - overlap_window(false) + ranged_data(false) overlap_window(false) signed_indice(false) ins(%update, %indices : tensor, tensor) outs(%init : tensor) { ^bb0(%arg1: i64, %arg2: i64): @@ -515,8 +493,7 @@ func.func @scatter_init_rank_mismatch( %init : tensor) -> tensor { // expected-error @+1 {{expected init value to be at least rank 1}} %0 = linalg_ext.scatter dimension_map = [0] - ranged_data(false) - overlap_window(false) + ranged_data(false) overlap_window(false) signed_indice(false) ins(%update, %indices : tensor, tensor) outs(%init : tensor) { ^bb0(%arg1: i64, %arg2: i64): @@ -533,8 +510,7 @@ func.func @scatter_init_rank_mismatch( %init : tensor) -> tensor { // expected-error @+1 {{expected update value to be at least rank 2}} %0 = linalg_ext.scatter dimension_map = [0] - 
ranged_data(false) - overlap_window(false) + ranged_data(false) overlap_window(false) signed_indice(false) ins(%update, %indices : tensor, tensor) outs(%init : tensor) { ^bb0(%arg1: i64, %arg2: i64): @@ -551,8 +527,7 @@ func.func @scatter_mask_shape_mismatch( %init : tensor) -> tensor { // expected-error @+1 {{mismatch in shape of mask and update value at batch dim}} %0 = linalg_ext.scatter dimension_map = [0] - ranged_data(false) - overlap_window(false) + ranged_data(false) overlap_window(false) signed_indice(false) ins(%update, %indices, %mask : tensor, tensor, tensor<8xi1>) outs(%init : tensor) { ^bb0(%arg1: i64, %arg2: i64): @@ -569,8 +544,7 @@ func.func @scatter_mask_type_mismatch( %init : tensor) -> tensor { // expected-error @+1 {{expected mask to be of i1 element type and batch matched init}} %0 = linalg_ext.scatter dimension_map = [0] - ranged_data(false) - overlap_window(false) + ranged_data(false) overlap_window(false) signed_indice(false) ins(%update, %indices, %mask : tensor, tensor, tensor) outs(%init : tensor) { ^bb0(%arg1: i64, %arg2: i64): @@ -587,8 +561,7 @@ func.func @scatter_indice_type_mismatch( %init : tensor) -> tensor { // expected-error @+1 {{expected indices to be of rank 2 of i8/i16/i32/i64 element type}} %0 = linalg_ext.scatter dimension_map = [0] - ranged_data(false) - overlap_window(false) + ranged_data(false) overlap_window(false) signed_indice(false) ins(%update, %indices, %mask : tensor, tensor, tensor) outs(%init : tensor) { ^bb0(%arg1: i64, %arg2: i64): @@ -605,7 +578,7 @@ func.func @gather_extra_outputs( %input : tensor) -> (tensor, tensor) { // expected-error @+1 {{expected the number of tensor results (2) to be equal to the number of output tensors (1)}} %0, %1 = linalg_ext.gather dimension_map = [0] - ranged_data(false) + ranged_data(false) signed_indice(false) ins(%input, %indices : tensor, tensor) outs(%init : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -621,7 +594,7 @@ func.func @gather_nd_batch_exceed_dim( %input : tensor<4x4xf32>) -> tensor<1x1x8x8xf32> { // expected-error @+1 {{indexed shape of init value dim#2 exceeds input value at dim#0 8 .vs. 
4}} %0 = linalg_ext.gather dimension_map = [0, 1] - ranged_data(false) + ranged_data(false) signed_indice(false) ins(%input, %indices : tensor<4x4xf32>, tensor<1x1x2xi32>) outs(%init : tensor<1x1x8x8xf32>) { ^bb0(%arg1: f32, %arg2: f32): @@ -637,7 +610,7 @@ func.func @gather_mistmatch_dim_map_entries( %input : tensor) -> tensor { // expected-error @+1 {{invalid number of dimension map entries}} %0 = linalg_ext.gather dimension_map = [0, 1] - ranged_data(false) + ranged_data(false) signed_indice(false) ins(%input, %indices : tensor, tensor) outs(%init : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -653,7 +626,7 @@ func.func @gather_duplicate_dim_map_entries( %input : tensor) -> tensor { // expected-error @+1 {{dimension map is invalid}} %0 = linalg_ext.gather dimension_map = [1, 1] - ranged_data(false) + ranged_data(false) signed_indice(false) ins(%input, %indices : tensor, tensor) outs(%init : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -669,7 +642,7 @@ func.func @gather_invalid_dim_map_entries( %input : tensor) -> tensor { // expected-error @+1 {{dimension map is invalid}} %0 = linalg_ext.gather dimension_map = [2] - ranged_data(false) + ranged_data(false) signed_indice(false) ins(%input, %indices : tensor, tensor) outs(%init : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -685,7 +658,7 @@ func.func @gather_dim_mismatch( %input : tensor) -> tensor { // expected-error @+1 {{mismatch in shape of indices and init value at batch dim}} %0 = linalg_ext.gather dimension_map = [0] - ranged_data(false) + ranged_data(false) signed_indice(false) ins(%input, %indices : tensor, tensor<48x1xi32>) outs(%init : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -701,7 +674,7 @@ func.func @gather_dim_mismatch( %input : tensor) -> tensor { // expected-error @+1 {{op init value rank exceeds the rank of the input value}} %0 = linalg_ext.gather dimension_map = [0] - ranged_data(false) + ranged_data(false) signed_indice(false) ins(%input, %indices : tensor, tensor) outs(%init : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -717,7 +690,7 @@ func.func @gather_dim_mismatch( %input : tensor) -> tensor { // expected-error @+1 {{op indexed shape of init value dim#2 exceeds input value at dim#1}} %0 = linalg_ext.gather dimension_map = [0] - ranged_data(false) + ranged_data(false) signed_indice(false) ins(%input, %indices : tensor, tensor) outs(%init : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -733,7 +706,7 @@ func.func @gather_region_type_mismatch( %input : tensor) -> tensor { // expected-error @+1 {{expected region to have scalar argument of integer or float types}} %0 = linalg_ext.gather dimension_map = [0] - ranged_data(false) + ranged_data(false) signed_indice(false) ins(%input, %indices : tensor, tensor) outs(%init : tensor) { ^bb0(%arg1: index, %arg2: index): @@ -750,7 +723,7 @@ func.func @gather_region_type_mismatch( %input : tensor) -> tensor { // expected-error @+1 {{mismatch in argument 0 of region 'i64' and element type of init value 'i32'}} %0 = linalg_ext.gather dimension_map = [0] - ranged_data(false) + ranged_data(false) signed_indice(false) ins(%input, %indices : tensor, tensor) outs(%init : tensor) { ^bb0(%arg1: i64, %arg2: i32): @@ -767,7 +740,7 @@ func.func @gather_region_type_mismatch( %input : tensor) -> tensor { // expected-error @+1 {{mismatch in argument 1 of region 'i64' and element type of input value 'i32'}} %0 = linalg_ext.gather dimension_map = [0] - ranged_data(false) + ranged_data(false) signed_indice(false) ins(%input, %indices : tensor, tensor) outs(%init : tensor) { ^bb0(%arg1: i32, %arg2: i64): @@ 
-783,7 +756,7 @@ func.func @gather_yield_mismatch( %init : tensor, %indices : tensor, %input : tensor) -> tensor { %0 = linalg_ext.gather dimension_map = [0] - ranged_data(false) + ranged_data(false) signed_indice(false) ins(%input, %indices : tensor, tensor) outs(%init : tensor) { ^bb0(%arg1: i64, %arg2: i64): @@ -801,7 +774,7 @@ func.func @gather_index_depth_dynamic( %input : tensor) -> tensor { // expected-error @+1 {{expected index depth is static}} %0 = linalg_ext.gather dimension_map = [0] - ranged_data(false) + ranged_data(false) signed_indice(false) ins(%input, %indices : tensor, tensor) outs(%init : tensor) { ^bb0(%arg1: i64, %arg2: i64): @@ -818,7 +791,7 @@ func.func @gather_input_rank_mismatch( %input : tensor) -> tensor { // expected-error @+1 {{expected input value to be at least rank 1}} %0 = linalg_ext.gather dimension_map = [0] - ranged_data(false) + ranged_data(false) signed_indice(false) ins(%input, %indices : tensor, tensor) outs(%init : tensor) { ^bb0(%arg1: i64, %arg2: i64): @@ -835,7 +808,7 @@ func.func @gather_input_rank_mismatch( %input : tensor) -> tensor { // expected-error @+1 {{expected init value to be at least rank 2}} %0 = linalg_ext.gather dimension_map = [0] - ranged_data(false) + ranged_data(false) signed_indice(false) ins(%input, %indices : tensor, tensor) outs(%init : tensor) { ^bb0(%arg1: i64, %arg2: i64): @@ -852,7 +825,7 @@ func.func @gather_mask_shape_mismatch( %input : tensor) -> tensor { // expected-error @+1 {{mismatch in shape of mask and init value at batch dim}} %0 = linalg_ext.gather dimension_map = [0] - ranged_data(false) + ranged_data(false) signed_indice(false) ins(%input, %indices, %mask : tensor, tensor, tensor<8xi1>) outs(%init : tensor) { ^bb0(%arg1: i64, %arg2: i64): @@ -869,7 +842,7 @@ func.func @gather_mask_type_mismatch( %input : tensor) -> tensor { // expected-error @+1 {{expected mask to be of i1 element type and batch matched init}} %0 = linalg_ext.gather dimension_map = [0] - ranged_data(false) + ranged_data(false) signed_indice(false) ins(%input, %indices, %mask : tensor, tensor, tensor) outs(%init : tensor) { ^bb0(%arg1: i64, %arg2: i64): @@ -886,7 +859,7 @@ func.func @gather_indice_type_mismatch( %input : tensor) -> tensor { // expected-error @+1 {{expected indices to be of rank 2 of i8/i16/i32/i64 element type}} %0 = linalg_ext.gather dimension_map = [0] - ranged_data(false) + ranged_data(false) signed_indice(false) ins(%input, %indices, %mask : tensor, tensor, tensor) outs(%init : tensor) { ^bb0(%arg1: i64, %arg2: i64): @@ -988,7 +961,6 @@ func.func @pad_pvalue_type_mismatch(%input : tensor<4x4xf32>, %init : tensor<6x8 } // ----- - func.func @scan_unmatched_output_and_init_num(%input: tensor<16x32x64xf32>, %output0: tensor<16x32x64xf32>, %output1: tensor<16x32x64xf32>, @@ -998,6 +970,7 @@ func.func @scan_unmatched_output_and_init_num(%input: tensor<16x32x64xf32>, ins(%input: tensor<16x32x64xf32>) outs(%output0, %output1, %init: tensor<16x32x64xf32>, tensor<16x32x64xf32>, tensor<16x64xf32>) dimensions = [1] + reverse = false { ^bb0(%in: f32, %out0: f32, %out1: f32, %ini: f32): %0 = arith.addf %ini, %in: f32 @@ -1007,7 +980,6 @@ func.func @scan_unmatched_output_and_init_num(%input: tensor<16x32x64xf32>, } // ----- - func.func @scan_unmatched_dim(%input: tensor<16x32x64xf32>, %output: tensor<16x64xf32>, %init: tensor<16xf32>) { @@ -1016,6 +988,7 @@ func.func @scan_unmatched_dim(%input: tensor<16x32x64xf32>, ins(%input: tensor<16x32x64xf32>) outs(%output, %init: tensor<16x64xf32>, tensor<16xf32>) dimensions = [1] + reverse = 
false { ^bb0(%in: f32, %out: f32, %ini: f32): %0 = arith.addf %ini, %in: f32 @@ -1025,7 +998,6 @@ func.func @scan_unmatched_dim(%input: tensor<16x32x64xf32>, } // ----- - func.func @scan_out_of_range(%input: tensor<16x32x64xf32>, %output: tensor<16x32x64xf32>, %init: tensor<32x64xf32>) { @@ -1034,6 +1006,7 @@ func.func @scan_out_of_range(%input: tensor<16x32x64xf32>, ins(%input: tensor<16x32x64xf32>) outs(%output, %init: tensor<16x32x64xf32>, tensor<32x64xf32>) dimensions = [4] + reverse = false { ^bb0(%in: f32, %out: f32, %ini: f32): %0 = arith.addf %ini, %in: f32 @@ -1043,7 +1016,6 @@ func.func @scan_out_of_range(%input: tensor<16x32x64xf32>, } // ----- - func.func @scan_unmatched_input_and_output_shape(%input0: tensor<16x32x64xf32>, %input1: tensor<16x32x64xf32>, %output0: tensor<32x64xf32>, @@ -1055,6 +1027,7 @@ func.func @scan_unmatched_input_and_output_shape(%input0: tensor<16x32x64xf32>, ins(%input0, %input1: tensor<16x32x64xf32>, tensor<16x32x64xf32>) outs(%output0, %output1, %init0, %init1: tensor<32x64xf32>, tensor<32x64xf32>, tensor<32x64xf32>, tensor<32x64xf32>) dimensions = [0] + reverse = false { ^bb0(%in0: f32, %in1: f32, %out0: f32, %out1: f32, %ini0: f32, %ini1: f32): %0 = arith.addf %ini0, %in0: f32 @@ -1065,7 +1038,6 @@ func.func @scan_unmatched_input_and_output_shape(%input0: tensor<16x32x64xf32>, } // ----- - func.func @scan_unmatched_inputs_shape(%input0: tensor<16x32x64xf32>, %input1: tensor<32x64xf32>, %output0: tensor<16x32x64xf32>, @@ -1077,6 +1049,7 @@ func.func @scan_unmatched_inputs_shape(%input0: tensor<16x32x64xf32>, ins(%input0, %input1: tensor<16x32x64xf32>, tensor<32x64xf32>) outs(%output0, %output1, %init0, %init1: tensor<16x32x64xf32>, tensor<16x32x64xf32>, tensor<32x64xf32>, tensor<32x64xf32>) dimensions = [0] + reverse = false { ^bb0(%in0: f32, %in1: f32, %out0: f32, %out1: f32, %ini0: f32, %ini1: f32): %0 = arith.addf %ini0, %in0: f32 @@ -1087,7 +1060,6 @@ func.func @scan_unmatched_inputs_shape(%input0: tensor<16x32x64xf32>, } // ----- - func.func @scan_unmatched_outputs_shape(%input0: tensor<16x32x64xf32>, %input1: tensor<16x32x64xf32>, %output0: tensor<16x32x64xf32>, @@ -1099,6 +1071,7 @@ func.func @scan_unmatched_outputs_shape(%input0: tensor<16x32x64xf32>, ins(%input0, %input1: tensor<16x32x64xf32>, tensor<16x32x64xf32>) outs(%output0, %output1, %init0, %init1: tensor<16x32x64xf32>, tensor<32x64xf32>, tensor<32x64xf32>, tensor<32x64xf32>) dimensions = [0] + reverse = false { ^bb0(%in0: f32, %in1: f32, %out0: f32, %out1: f32, %ini0: f32, %ini1: f32): %0 = arith.addf %ini0, %in0: f32 @@ -1109,7 +1082,6 @@ func.func @scan_unmatched_outputs_shape(%input0: tensor<16x32x64xf32>, } // ----- - func.func @scan_unmatched_inits_shape(%input0: tensor<16x32x64xf32>, %input1: tensor<16x32x64xf32>, %output0: tensor<16x32x64xf32>, @@ -1121,6 +1093,7 @@ func.func @scan_unmatched_inits_shape(%input0: tensor<16x32x64xf32>, ins(%input0, %input1: tensor<16x32x64xf32>, tensor<16x32x64xf32>) outs(%output0, %output1, %init0, %init1: tensor<16x32x64xf32>, tensor<16x32x64xf32>, tensor<32x64xf32>, tensor<16x64xf32>) dimensions = [0] + reverse = false { ^bb0(%in0: f32, %in1: f32, %out0: f32, %out1: f32, %ini0: f32, %ini1: f32): %0 = arith.addf %ini0, %in0: f32 @@ -1131,7 +1104,6 @@ func.func @scan_unmatched_inits_shape(%input0: tensor<16x32x64xf32>, } // ----- - func.func @scan_unexpected_inits_shape(%input0: tensor<16x32x64xf32>, %input1: tensor<16x32x64xf32>, %output0: tensor<16x32x64xf32>, @@ -1143,6 +1115,7 @@ func.func @scan_unexpected_inits_shape(%input0: 
tensor<16x32x64xf32>, ins(%input0, %input1: tensor<16x32x64xf32>, tensor<16x32x64xf32>) outs(%output0, %output1, %init0, %init1: tensor<16x32x64xf32>, tensor<16x32x64xf32>, tensor<16x64xf32>, tensor<16x64xf32>) dimensions = [0] + reverse = false { ^bb0(%in0: f32, %in1: f32, %out0: f32, %out1: f32, %ini0: f32, %ini1: f32): %0 = arith.addf %ini0, %in0: f32 @@ -1153,7 +1126,6 @@ func.func @scan_unexpected_inits_shape(%input0: tensor<16x32x64xf32>, } // ----- - func.func @scan_unmatched_block_args_num(%input: tensor<16xf32>, %output: tensor<16xf32>, %init: tensor) { @@ -1162,6 +1134,7 @@ func.func @scan_unmatched_block_args_num(%input: tensor<16xf32>, ins(%input: tensor<16xf32>) outs(%output, %init: tensor<16xf32>, tensor) dimensions = [0] + reverse = false { ^bb0(%in: f32, %out0: f32, %out1: f32, %ini: f32): %0 = arith.addf %ini, %in: f32 @@ -1171,7 +1144,6 @@ func.func @scan_unmatched_block_args_num(%input: tensor<16xf32>, } // ----- - func.func @scan_unmatched_element_type(%input: tensor<16xf32>, %output: tensor<16xf32>, %init: tensor) { @@ -1180,6 +1152,7 @@ func.func @scan_unmatched_element_type(%input: tensor<16xf32>, ins(%input: tensor<16xf32>) outs(%output, %init: tensor<16xf32>, tensor) dimensions = [0] + reverse = false { ^bb0(%in: i32, %out: i32, %ini: i32): %0 = arith.addi %ini, %in: i32 @@ -1189,7 +1162,6 @@ func.func @scan_unmatched_element_type(%input: tensor<16xf32>, } // ----- - func.func @scan_multi_operands(%input0: tensor<16x32x64xf32>, %input1: tensor<16x32x64xf32>, %output0: tensor<16x32x64xf32>, @@ -1201,6 +1173,7 @@ func.func @scan_multi_operands(%input0: tensor<16x32x64xf32>, ins(%input0, %input1: tensor<16x32x64xf32>, tensor<16x32x64xf32>) outs(%output0, %output1, %init0, %init1: tensor<16x32x64xf32>, tensor<16x32x64xf32>, tensor<32x64xf32>, tensor<32x64xf32>) dimensions = [0, 1] + reverse = false { ^bb0(%in0: f32, %in1: f32, %out0: f32, %out1: f32, %ini0: f32, %ini1: f32): %0 = arith.addf %ini0, %in0: f32 @@ -1209,3 +1182,21 @@ func.func @scan_multi_operands(%input0: tensor<16x32x64xf32>, } func.return } + +// ----- +func.func @scalar_libdevice_call_input_invalid(%arg0: tensor) -> f32 { + // expected-error @+1 {{'linalg_ext.scalar_libdevice_call' op expects all input types are scalar type.}} + %libdevicecall = linalg_ext.scalar_libdevice_call + ins(%arg0 : tensor) + symbol = "__cn_scalar_abs_f32" -> f32 + return %libdevicecall : f32 +} + +// ----- +func.func @scalar_libdevice_call_result_invalid(%arg0: f32) -> tensor { + // expected-error @+1 {{'linalg_ext.scalar_libdevice_call' op expects the result type is scalar type.}} + %libdevicecall = linalg_ext.scalar_libdevice_call + ins(%arg0, %arg0 : f32, f32) + symbol = "__cn_scalar_add_f32" -> tensor + return %libdevicecall : tensor +} diff --git a/test/Dialect/LinalgExt/ops.mlir b/test/Dialect/LinalgExt/ops.mlir index 720c2dd..a8a5eb0 100644 --- a/test/Dialect/LinalgExt/ops.mlir +++ b/test/Dialect/LinalgExt/ops.mlir @@ -24,18 +24,6 @@ func.func @transpose_memref(%input: memref<16x32x64xf32>, // ----- -func.func @transpose_memref(%input: memref<16x32x64xf32>, - %init: memref<16x32x64xf32>) { - linalg.transpose - ins(%input:memref<16x32x64xf32>) - outs(%init:memref<16x32x64xf32>) - permutation = [0, 1, 2] - func.return -} -// CHECK-LABEL: func @transpose_memref - -// ----- - func.func @map(%lhs: tensor<16x32x64xf32>, %rhs: tensor<16x32x64xf32>, %init: tensor<16x32x64xf32>) { linalg.map @@ -187,28 +175,13 @@ func.func @im2col_memref(%input: memref<32x16x16x256xf32, 101>, %init: memref<32 return %init : 
memref<32x14x14x3x3x256xf32, 101> } -// ----- -// CHECK: linalg_ext.scatter -func.func @scatter_tensor(%A : tensor<4x1xi32>, %B: tensor<4x2x4xf32>, %C: tensor<16x8xf32>) -> tensor<16x8xf32> { - %scatter = linalg_ext.scatter - dimension_map = [1] - ranged_data(true) - overlap_window(false) - ins(%B, %A: tensor<4x2x4xf32>, tensor<4x1xi32>) - outs(%C: tensor<16x8xf32>) { - ^bb0(%arg0 :f32, %arg1: f32): - linalg_ext.yield %arg0 : f32 - } -> tensor<16x8xf32> - return %scatter : tensor<16x8xf32> -} - // ----- // CHECK: linalg_ext.scatter func.func @scatter_tensor(%A : tensor<4x1xi32>, %B: tensor<4x2x4xf32>, %C: tensor<16x8xf32>, %D: tensor<4xi1>) -> tensor<16x8xf32> { %scatter = linalg_ext.scatter dimension_map = [1] ranged_data(true) - overlap_window(false) + overlap_window(false) signed_indice(true) ins(%B, %A, %D: tensor<4x2x4xf32>, tensor<4x1xi32>, tensor<4xi1>) outs(%C: tensor<16x8xf32>) { ^bb0(%arg0 :f32, %arg1: f32): @@ -223,7 +196,7 @@ func.func @scatter_tensor_i64_indice(%indices : tensor<4x1xi64>, %window: tensor %scatter = linalg_ext.scatter dimension_map = [1] ranged_data(true) - overlap_window(false) + overlap_window(false) signed_indice(true) ins(%window, %indices, %mask: tensor<4x2x4xf32>, tensor<4x1xi64>, tensor<4xi1>) outs(%data: tensor<16x8xf32>) { ^bb0(%arg0 :f32, %arg1: f32): @@ -238,7 +211,7 @@ func.func @scatter_tensor_i16_indice(%indices : tensor<4x1xi16>, %window: tensor %scatter = linalg_ext.scatter dimension_map = [1] ranged_data(true) - overlap_window(false) + overlap_window(false) signed_indice(true) ins(%window, %indices, %mask: tensor<4x2x4xf32>, tensor<4x1xi16>, tensor<4xi1>) outs(%data: tensor<16x8xf32>) { ^bb0(%arg0 :f32, %arg1: f32): @@ -254,7 +227,7 @@ func.func @scatter_nd_batch( %init : tensor<4x4xf32>) -> tensor<4x4xf32> { %0 = linalg_ext.scatter dimension_map = [0, 1] ranged_data(false) - overlap_window(false) + overlap_window(false) signed_indice(true) ins(%update, %indices : tensor<1x1x2x2xf32>, tensor<1x1x2xi32>) outs(%init : tensor<4x4xf32>) { ^bb0(%arg1: f32, %arg2: f32): @@ -264,26 +237,13 @@ func.func @scatter_nd_batch( return %0 : tensor<4x4xf32> } -// ----- -// CHECK: linalg_ext.gather -func.func @gather_tensor(%A : tensor<4x1xi32>, %B: tensor<4x2x4xf32>, %C: tensor<16x8xf32>) -> tensor<4x2x4xf32> { - %gather = linalg_ext.gather - dimension_map = [1] - ranged_data(true) - ins(%C, %A: tensor<16x8xf32>, tensor<4x1xi32>) - outs(%B: tensor<4x2x4xf32>) { - ^bb0(%arg0 :f32, %arg1: f32): - linalg_ext.yield %arg0 : f32 - } -> tensor<4x2x4xf32> - return %gather : tensor<4x2x4xf32> -} // ----- // CHECK: linalg_ext.gather func.func @gather_tensor(%A : tensor<4x1xi32>, %B: tensor<4x2x4xf32>, %C: tensor<16x8xf32>, %D: tensor<4xi1>) -> tensor<4x2x4xf32> { %gather = linalg_ext.gather dimension_map = [1] - ranged_data(true) + ranged_data(true) signed_indice(false) ins(%C, %A, %D: tensor<16x8xf32>, tensor<4x1xi32>, tensor<4xi1>) outs(%B: tensor<4x2x4xf32>) { ^bb0(%arg0 :f32, %arg1: f32): @@ -292,26 +252,12 @@ func.func @gather_tensor(%A : tensor<4x1xi32>, %B: tensor<4x2x4xf32>, %C: tensor return %gather : tensor<4x2x4xf32> } -// ----- -// CHECK: linalg_ext.gather -func.func @gather_tensor(%indices : tensor<4x1xi64>, %window: tensor<4x2x4xf32>, %data: tensor<16x8xf32>, %mask: tensor<4xi1>) -> tensor<4x2x4xf32> { - %gather = linalg_ext.gather - dimension_map = [1] - ranged_data(true) - ins(%data, %indices, %mask: tensor<16x8xf32>, tensor<4x1xi64>, tensor<4xi1>) - outs(%window: tensor<4x2x4xf32>) { - ^bb0(%arg0 :f32, %arg1: f32): - linalg_ext.yield %arg0 : f32 - } -> 
tensor<4x2x4xf32> - return %gather : tensor<4x2x4xf32> -} - // ----- // CHECK: linalg_ext.gather func.func @gather_tensor_i8_indice(%indices : tensor<4x1xi8>, %window: tensor<4x2x4xf32>, %data: tensor<16x8xf32>, %mask: tensor<4xi1>) -> tensor<4x2x4xf32> { %gather = linalg_ext.gather dimension_map = [1] - ranged_data(true) + ranged_data(true) signed_indice(false) ins(%data, %indices, %mask: tensor<16x8xf32>, tensor<4x1xi8>, tensor<4xi1>) outs(%window: tensor<4x2x4xf32>) { ^bb0(%arg0 :f32, %arg1: f32): @@ -326,7 +272,7 @@ func.func @gather_nd_batch( %input : tensor<4x4xf32>) -> tensor<1x1x2x2xf32> { // expected-error @+1 {{indexed shape of init value dim#2 exceeds input value at dim#0 1 .vs. 4}} %0 = linalg_ext.gather dimension_map = [0, 1] - ranged_data(false) + ranged_data(false) signed_indice(false) ins(%input, %indices : tensor<4x4xf32>, tensor<1x1x2xi32>) outs(%init : tensor<1x1x2x2xf32>) { ^bb0(%arg1: f32, %arg2: f32): @@ -336,17 +282,31 @@ func.func @gather_nd_batch( return %0 : tensor<1x1x2x2xf32> } +// ----- +// CHECK: linalg_ext.gather_atomic_rmw +func.func @discrete_atomic_addf_with_mask(%arg0: memref<4x1xf32, 101>, %arg1: memref<4x1xi32, 101>, %arg2: memref<4xi8, 101>, %arg3: memref, %arg4: memref<4x1xf32, 101>) { + linalg_ext.gather_atomic_rmw addf relaxed ins(%arg0, %arg1, %arg2 : memref<4x1xf32, 101>, memref<4x1xi32, 101>, memref<4xi8, 101>) outs(%arg3, %arg4 : memref, memref<4x1xf32, 101>) + return +} + +// ----- +// CHECK: linalg_ext.atomic_rmw +func.func @atomic_contiguous(%alloc_5: memref<4xf32, 101>, %view_memref: memref<4xf32, 1>, %alloc_4: memref<4xf32, 101>) { + linalg_ext.atomic_rmw addf release ins(%alloc_5 : memref<4xf32, 101>) outs(%view_memref, %alloc_4 : memref<4xf32, 1>, memref<4xf32, 101>) -> memref<4xf32, 1>, memref<4xf32, 101> + return +} + // ----- // CHECK: linalg_ext.atomic_cas func.func @atomic_cas(%arg0: tensor<128xi32>, %cmp: tensor<128xi32>, %val: tensor<128xi32>, %init: tensor<128xi32>) -> tensor<128xi32> { - %0 = linalg_ext.atomic_cas ins(%arg0, %cmp, %val : tensor<128xi32>, tensor<128xi32>, tensor<128xi32>) outs(%init : tensor<128xi32>) -> tensor<128xi32> + %0 = linalg_ext.atomic_cas relaxed ins(%arg0, %cmp, %val : tensor<128xi32>, tensor<128xi32>, tensor<128xi32>) outs(%init : tensor<128xi32>) -> tensor<128xi32> return %0: tensor<128xi32> } // ----- // CHECK: linalg_ext.gather_atomic_cas func.func @gather_atomic_cas(%in: tensor, %cmp: tensor<128xi32>, %val: tensor<128xi32>, %indice: tensor<128xi64>, %init: tensor<128xi32>) -> tensor<128xi32> { - %4 = linalg_ext.gather_atomic_cas ins(%in, %cmp, %val, %indice: tensor, tensor<128xi32>, tensor<128xi32>, tensor<128xi64>) outs(%init : tensor<128xi32>) -> tensor<128xi32> + %4 = linalg_ext.gather_atomic_cas release ins(%in, %cmp, %val, %indice: tensor, tensor<128xi32>, tensor<128xi32>, tensor<128xi64>) outs(%init : tensor<128xi32>) -> tensor<128xi32> return %4: tensor<128xi32> } @@ -458,6 +418,24 @@ func.func @libdevice_call_total_dynamic(%lhs: memref, %rhs: memr func.return } +// ----- +func.func @scalar_libdevice_call_one_input(%arg0: f32) -> f32 { + // CHECK: linalg_ext.scalar_libdevice_call + %libdevicecall = linalg_ext.scalar_libdevice_call + ins(%arg0 : f32) + symbol = "__cn_scalar_abs_f32" -> f32 + return %libdevicecall : f32 +} + +// ----- +func.func @scalar_libdevice_call_two_input(%arg0: f32) -> f32 { + // CHECK: linalg_ext.scalar_libdevice_call + %libdevicecall = linalg_ext.scalar_libdevice_call + ins(%arg0, %arg0 : f32, f32) + symbol = "__cn_scalar_add_f32" -> f32 + return %libdevicecall : f32 
+} + + // ----- // CHECK: linalg_ext.pad func.func @pad_memref(%input : memref<4x4x16xf32>, %init : memref<6x8x16xf32>, %pvalue : f32) { @@ -573,6 +551,7 @@ func.func @scan_tensor(%input: tensor<16x32x64xf32>, ins(%input: tensor<16x32x64xf32>) outs(%output, %init: tensor<16x32x64xf32>, tensor<16x64xf32>) dimensions = [1] + reverse = false { ^bb0(%in: f32, %out: f32, %ini: f32): %0 = arith.addf %ini, %in: f32 @@ -590,6 +569,7 @@ func.func @scan_memref(%input: memref<16x32x64xf32>, ins(%input: memref<16x32x64xf32>) outs(%output, %init: memref<16x32x64xf32>, memref<16x64xf32>) dimensions = [1] + reverse = false { ^bb0(%in: f32, %out: f32, %ini: f32): %0 = arith.addf %ini, %in: f32 diff --git a/test/Dialect/Triton/extract-move-backward.mlir b/test/Dialect/Triton/extract-move-backward.mlir index d02d1ab..508146f 100644 --- a/test/Dialect/Triton/extract-move-backward.mlir +++ b/test/Dialect/Triton/extract-move-backward.mlir @@ -182,9 +182,9 @@ func.func @extract_element_from_map_arith_add_with_two_indices( // ----- // CHECK-LABEL: @extract_element_from_map_with_two_payloads func.func @extract_element_from_map_with_two_payloads(%arg0: tensor<32xi64>, %arg1: tensor<32xi32>) -> i64 { - // CHECK: %[[C0_INDEX:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C0_INDEX:.*]] = arith.constant 0 : index %c0 = arith.constant 0 : index - // CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 + // CHECK-DAG: %[[C1_I64:.*]] = arith.constant 1 : i64 %c1 = arith.constant 1 : i64 // CHECK: %[[ARG_EXTRA0:.*]] = tensor.extract %arg0[%[[C0_INDEX]]] : tensor<32xi64> // CHECK: %[[ARG_EXTRA1:.*]] = tensor.extract %arg1[%[[C0_INDEX]]] : tensor<32xi32> @@ -691,7 +691,7 @@ func.func @extract_from_for_iter_args_failed(%arg0: i64, %arg1: tensor<64x64xf32 // CHECK: %[[VAL_6:.*]] = tensor.extract_slice %[[VAL_5]][%[[VAL_2]], %[[VAL_2]]] [64, 64] [1, 1] : tensor<64x?xf32> to tensor<64x64xf32> // CHECK: %[[VAL_7:.*]] = linalg.map { math.absf } ins(%[[VAL_6]] : tensor<64x64xf32>) outs(%[[ARG5]] : tensor<64x64xf32>) // CHECK: %[[VAL_8:.*]] = tensor.empty() : tensor<64x64xi64> -// CHECK: %[[VAL_9:.*]] = linalg.map { arith.addi } ins(%[[ARG6]], %[[ARG3]] : tensor<64x64xi64>, tensor<64x64xi64>) outs(%[[VAL_8]] : tensor<64x64xi64>) +// CHECK: %[[VAL_9:.*]] = linalg.map { arith.addi {overflowFlags = #arith.overflow<none>} } ins(%[[ARG6]], %[[ARG3]] : tensor<64x64xi64>, tensor<64x64xi64>) outs(%[[VAL_8]] : tensor<64x64xi64>) // CHECK: %[[VAL_10:.*]] = tensor.empty() : tensor<64x64xi64> // CHECK: %[[VAL_11:.*]] = linalg.transpose ins(%[[VAL_9]] : tensor<64x64xi64>) outs(%[[VAL_10]] : tensor<64x64xi64>) permutation = [1, 0] // CHECK: scf.yield %[[VAL_7]], %[[VAL_11]] : tensor<64x64xf32>, tensor<64x64xi64> diff --git a/test/Dialect/Triton/extractslice-move-backward.mlir b/test/Dialect/Triton/extractslice-move-backward.mlir index b50b216..90f4daf 100644 --- a/test/Dialect/Triton/extractslice-move-backward.mlir +++ b/test/Dialect/Triton/extractslice-move-backward.mlir @@ -152,7 +152,7 @@ func.func @extract_slice_from_collapse_shape_op_with_0_rank(%arg0 : tensor<1x1xf // CHECK: %[[VAL_2:.*]] = tensor.extract_slice %[[VAL_0]][0, 0] [16, 1] [2, 2] : tensor<128x16xi32> to tensor<16x1xi32> // CHECK: %[[VAL_3:.*]] = tensor.extract_slice %[[VAL_1]][0, 0] [16, 1] [2, 2] : tensor<128x16xi32> to tensor<16x1xi32> // CHECK: %[[VAL_4:.*]] = tensor.empty() : tensor<16x1xi32> -// CHECK: %[[VAL_5:.*]] = linalg.map { arith.addi } ins(%[[VAL_2]], %[[VAL_3]] : tensor<16x1xi32>, tensor<16x1xi32>) outs(%[[VAL_4]] : tensor<16x1xi32>) +// CHECK: %[[VAL_5:.*]] = linalg.map { arith.addi {overflowFlags = #arith.overflow<none>} } ins(%[[VAL_2]], %[[VAL_3]] : tensor<16x1xi32>, tensor<16x1xi32>) outs(%[[VAL_4]] : tensor<16x1xi32>) // CHECK: %[[VAL_6:.*]] = tensor.collapse_shape %[[VAL_5]] {{\[\[}}0, 1]] : tensor<16x1xi32> into tensor<16xi32> // CHECK: return %[[VAL_6]] : tensor<16xi32> // CHECK: } @@ -394,7 +394,7 @@ func.func @extractslice_outside_failed(%arg0: i64, %arg1: tensor<64x64xf32>, %ar // CHECK: "test.foo"(%[[VAL_8]]) : (tensor<128xi32>) -> () // CHECK: %[[VAL_10:.*]] = tensor.extract_slice %[[VAL_0]][0, 0] [128, 1] [1, 1] : tensor<128x64xi32> to tensor<128x1xi32> // CHECK: %[[VAL_11:.*]] = tensor.empty() : tensor<128x1xi32> -// CHECK: %[[VAL_12:.*]] = linalg.map { arith.addi } ins(%[[VAL_9]], %[[VAL_10]] : tensor<128x1xi32>, tensor<128x1xi32>) outs(%[[VAL_11]] : tensor<128x1xi32>) +// CHECK: %[[VAL_12:.*]] = linalg.map { arith.addi {overflowFlags = #arith.overflow<none>} } ins(%[[VAL_9]], %[[VAL_10]] : tensor<128x1xi32>, tensor<128x1xi32>) outs(%[[VAL_11]] : tensor<128x1xi32>) // CHECK: %[[VAL_13:.*]] = tensor.collapse_shape %[[VAL_12]] {{\[\[}}0, 1]] : tensor<128x1xi32> into tensor<128xi32> // CHECK: scf.yield %[[VAL_13]] : tensor<128xi32> // CHECK: } diff --git a/test/Pipelines/pipeline.mlir b/test/Pipelines/pipeline.mlir index acf622b..965be08 100644 --- a/test/Pipelines/pipeline.mlir +++ b/test/Pipelines/pipeline.mlir @@ -1,5 +1,44 @@ // RUN: triton-linalg-opt %s -triton-to-linalg -split-input-file | FileCheck %s +// CHECK-LABEL: func.func @add_kernel_01234 +// CHECK-SAME: (%[[ARG0:.+]]: i64, %[[ARG1:.+]]: i64, %[[ARG2:.+]]: i64, %[[ARG3:.+]]: i32) +// CHECK-DAG: %[[C1024:.+]] = arith.constant 1024 : index +// CHECK-DAG: %[[C1024_I32:.+]] = arith.constant 1024 : i32 +// CHECK-DAG: %[[C0_F32:.+]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL0:.+]] = tt.get_program_id x : i32 +// CHECK: %[[VAL1:.+]] = arith.muli %[[VAL0]], %[[C1024_I32]] : i32 +// CHECK: %[[VAL2:.+]] = arith.index_cast %[[VAL1]] : i32 to index +// CHECK: %[[VAL3:.+]] = arith.addi %[[VAL2]], %[[C1024]] : index +// CHECK: %[[VAL4:.+]] = arith.index_cast %[[ARG3]] : i32 to index +// CHECK: %[[VAL5:.+]] = arith.maxsi %[[VAL4]], %[[VAL2]] : index +// CHECK: %[[VAL6:.+]] = arith.minsi %[[VAL3]], %[[VAL5]] : index +// CHECK: %[[VAL7:.+]] = arith.subi %[[VAL6]], %[[VAL2]] : index +// CHECK: %[[VAL8:.+]] = llvm.inttoptr %[[ARG0]] : i64 to !llvm.ptr +// CHECK: %[[VIEW:.+]] = aux.view %[[VAL8]] to offset: [%[[VAL2]]], sizes: [%[[VAL7]]], strides: [1] : !llvm.ptr to memref +// CHECK: %[[VAL9:.+]] = bufferization.to_tensor %[[VIEW]] restrict writable : memref +// CHECK: %[[VAL10:.+]] = tensor.empty(%[[VAL7]]) : tensor +// CHECK: %[[VAL11:.+]] = linalg.copy ins(%[[VAL9]] : tensor) outs(%[[VAL10]] : tensor) -> tensor +// CHECK: %[[VAL12:.+]] = tensor.empty() : tensor<1024xf32> +// CHECK: %[[VAL13:.+]] = arith.subi %[[C1024]], %[[VAL7]] : index +// CHECK: %[[VAL14:.+]] = linalg_ext.pad ins(%[[VAL11]] : tensor) outs(%[[VAL12]] : tensor<1024xf32>) pvalue(%[[C0_F32]] : f32) low = [0] high = [%[[VAL13]]] { +// CHECK: ^bb0(%[[ARG4:.+]]: f32): +// CHECK: linalg_ext.yield %[[ARG4]] : f32 +// CHECK: } -> tensor<1024xf32> +// CHECK: %[[VAL15:.+]] = llvm.inttoptr %[[ARG1]] : i64 to !llvm.ptr +// CHECK: %[[VIEW0:.+]] = aux.view %[[VAL15]] to offset: [%[[VAL2]]], sizes: [%[[VAL7]]], strides: [1] : !llvm.ptr to memref +// CHECK: %[[VAL16:.+]] = bufferization.to_tensor %[[VIEW0]] restrict writable : memref +// CHECK: %[[VAL17:.+]] = linalg.copy ins(%[[VAL16]] : tensor) outs(%[[VAL10]] : tensor) -> tensor +// CHECK: 
%[[VAL18:.+]] = linalg_ext.pad ins(%[[VAL17]] : tensor) outs(%[[VAL12]] : tensor<1024xf32>) pvalue(%[[C0_F32]] : f32) low = [0] high = [%[[VAL13]]] { +// CHECK: ^bb0(%[[ARG4]]: f32): +// CHECK: linalg_ext.yield %[[ARG4]] : f32 +// CHECK: } -> tensor<1024xf32> +// CHECK: %[[MAPPED:.+]] = linalg.map { arith.addf } ins(%[[VAL14]], %[[VAL18]] : tensor<1024xf32>, tensor<1024xf32>) outs(%[[VAL12]] : tensor<1024xf32>) +// CHECK: %[[VAL19:.+]] = llvm.inttoptr %[[ARG2]] : i64 to !llvm.ptr +// CHECK: %[[VIEW1:.+]] = aux.view %[[VAL19]] to offset: [%[[VAL2]]], sizes: [%[[VAL7]]], strides: [1] : !llvm.ptr to memref +// CHECK: %[[EXTRACTED:.+]] = tensor.extract_slice %[[MAPPED]][0] [%[[VAL7]]] [1] : tensor<1024xf32> to tensor +// CHECK: bufferization.materialize_in_destination %[[EXTRACTED]] in writable %[[VIEW1]] : (tensor, memref) -> () +// CHECK: return + tt.func public @add_kernel_01234(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32) { %c1024_i32 = arith.constant 1024 : i32 %0 = tt.get_program_id x : i32 diff --git a/test/lit.cfg.py b/test/lit.cfg.py index da70bff..11b0425 100644 --- a/test/lit.cfg.py +++ b/test/lit.cfg.py @@ -45,7 +45,7 @@ # test_exec_root: The root path where tests should be run. config.test_exec_root = os.path.join(config.triton_linalg_obj_root, 'test') config.triton_tools_dir = os.path.join(config.triton_linalg_obj_root, - 'tools/triton-linalg-opt') + 'bin') config.filecheck_dir = os.path.join(config.triton_obj_root, 'bin', 'FileCheck') tool_dirs = [ diff --git a/tools/ci/daily/triton-linalg_daliy.pipeline b/tools/ci/daily/triton-linalg_daliy.pipeline new file mode 100644 index 0000000..29affda --- /dev/null +++ b/tools/ci/daily/triton-linalg_daliy.pipeline @@ -0,0 +1,130 @@ +library "cambricon-pipe-lib@master" +cnpipe { + task('clone') { + stage 'clone' + node { + labelSelector "cambricon.com/mm-daily":true + cardType 'MLU370' + } + container { + networkPolicy "cncl-no-internnet-access" + image 'yellow.hub.cambricon.com/genesis/devel/x86_64/triton_linalg:1.0.0-x86_64-ubuntu2004-prebuild-thirdparty-py_3_10' + runArgs "--network=host --privileged -v /usr/bin/cnmon:/usr/bin/cnmon --device=/dev/cambricon_dev0:/dev/cambricon_dev0 --device=/dev/cambricon_ctl" + } + resReq { + reqMlus 1 + lmtMlus 1 + reqCpu 30 + lmtCpu 30 + reqMemory '40Gi' + lmtMemory '40Gi' + } + stage 'clone' + script ''' + git clone https://github.com/Cambricon/triton-linalg.git + cd triton-linalg + git fetch origin pull/${pr_id}/head:local_test + git config --global url."http://gitmirror.cambricon.com/git_repos/".insteadOf https:// + git submodule update --init --recursive + git checkout local_test + git log -1 + cd .. 
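+        # NOTE: the insteadOf rewrite above routes every https:// fetch (including the submodule update) through the internal git mirror, since this container runs under the no-internet network policy.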
+ ''' + timeout 30 + stash 'triton-linalg', 'triton-linalg-pr' + } + task('check_pr') { + stage 'check_pr' + node { + labelSelector "cambricon.com/mm-daily":true + cardType 'MLU370' + } + container { + networkPolicy "cncl-no-internnet-access" + image 'yellow.hub.cambricon.com/genesis/devel/x86_64/triton_linalg:1.0.0-x86_64-ubuntu2004-prebuild-thirdparty-py_3_10' + runArgs "--network=host --privileged -v /usr/bin/cnmon:/usr/bin/cnmon --device=/dev/cambricon_dev0:/dev/cambricon_dev0 --device=/dev/cambricon_ctl" + } + resReq { + reqMlus 1 + lmtMlus 1 + reqCpu 30 + lmtCpu 30 + reqMemory '40Gi' + lmtMemory '40Gi' + } + unstash 'triton-linalg-pr' + script ''' + mkdir logs + set -e + cd triton-linalg + set -o pipefail + bash tools/scripts/lint_check/lint.sh | tee ${CI_WORK_DIR}/logs/lint_log || exit 1 + ''' + stash 'triton-linalg', 'triton-linalg-check' + stash 'logs', 'task_logs' + archiveLog 'logs/', false + } + task('build') { + stage 'build' + node { + labelSelector "cambricon.com/mm-daily":true + cardType 'MLU370' + } + container { + networkPolicy "cncl-no-internnet-access" + image 'yellow.hub.cambricon.com/genesis/devel/x86_64/triton_linalg:1.0.0-x86_64-ubuntu2004-prebuild-thirdparty-py_3_10' + runArgs "--network=host --privileged -v /usr/bin/cnmon:/usr/bin/cnmon --device=/dev/cambricon_dev0:/dev/cambricon_dev0 --device=/dev/cambricon_ctl" + } + resReq { + reqMlus 1 + lmtMlus 1 + reqCpu 30 + lmtCpu 30 + reqMemory '40Gi' + lmtMemory '40Gi' + } + unstash 'triton-linalg-pr' + script ''' + mkdir logs + set -e + export TRITON_PLUGIN_DIRS=${CI_WORK_DIR}/triton-linalg + cd triton-linalg/triton + set -o pipefail + TRITON_BUILD_WITH_CLANG_LLD=true TRITON_BUILD_WITH_CCACHE=true pip3 install -e python --no-build-isolation -vvv | tee ${CI_WORK_DIR}/logs/build_log || exit 1 + ''' + stash 'triton-linalg', 'triton-linalg-build' + stash 'logs', 'task_logs' + archiveLog 'logs/', false + } + task('test') { + stage 'test' + node { + labelSelector "cambricon.com/mm-daily":true + cardType 'MLU370' + } + container { + networkPolicy "cncl-no-internnet-access" + image 'yellow.hub.cambricon.com/genesis/devel/x86_64/triton_linalg:1.0.0-x86_64-ubuntu2004-prebuild-thirdparty-py_3_10' + runArgs "--network=host --privileged -v /usr/bin/cnmon:/usr/bin/cnmon --device=/dev/cambricon_dev0:/dev/cambricon_dev0 --device=/dev/cambricon_ctl" + } + resReq { + reqMlus 1 + lmtMlus 1 + reqCpu 30 + lmtCpu 30 + reqMemory '40Gi' + lmtMemory '40Gi' + } + unstash 'triton-linalg-build' + script ''' + mkdir logs + set -e + cd triton-linalg + set -o pipefail + bash tools/scripts/test_triton-linalg.sh test_linalg_unittest | tee ${CI_WORK_DIR}/logs/test_log || exit 1 + ''' + stash 'triton-linalg', 'triton-linalg-test' + stash 'logs', 'task_logs' + archiveLog 'logs/', false + } +} diff --git a/tools/scripts/lint_check/common.sh b/tools/scripts/lint_check/common.sh new file mode 100644 index 0000000..86babd1 --- /dev/null +++ b/tools/scripts/lint_check/common.sh @@ -0,0 +1,27 @@ +#!/bin/bash +# Copyright (C) [2022-2025] by Cambricon. + +FINAL_RET=0 +LATEST_RET=0 + +function update_ret() { + LATEST_RET="$?" + if [[ "${LATEST_RET}" -gt "${FINAL_RET}" ]]; then + FINAL_RET="${LATEST_RET}" + fi +} + +# Update the exit code after every command +function enable_update_ret() { + trap update_ret DEBUG +} + + +function check_ret() { + if (( "${FINAL_RET}" != 0 )); then + echo "Encountered failures. Check error messages and changes to the working" \ + "directory and git index (which may contain fixes) and try again." 
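+    # FINAL_RET holds the highest exit status recorded by the update_ret DEBUG trap after enable_update_ret has been called.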
+ fi + + exit "${FINAL_RET}" +} diff --git a/tools/scripts/lint_check/format_diff.py b/tools/scripts/lint_check/format_diff.py new file mode 100755 index 0000000..5e72dfd --- /dev/null +++ b/tools/scripts/lint_check/format_diff.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 +# +#===- format_diff.py - Diff Reformatter ----*- python3 -*--===# +# +# This file is licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# +#===------------------------------------------------------------------------===# +""" +This script reads input from a unified diff and reformats all the changed +lines. This is useful to reformat all the lines touched by a specific patch. +Example usage: + + git diff -U0 HEAD^ | python3 format_diff.py yapf -i + git diff -U0 HEAD^ | python3 format_diff.py clang-format -i + svn diff --diff-cmd=diff -x-U0 | python3 format_diff.py -p0 clang-format -i + +General usage: + | python3 format_diff.py [--regex] [--lines-style] [-p] binary [args for binary] + +It should be noted that the filename contained in the diff is used unmodified +to determine the source file to update. Users calling this script directly +should be careful to ensure that the path in the diff is correct relative to the +current working directory. +""" + +import argparse +import difflib +import io +import re +import subprocess +import sys + +BINARY_TO_DEFAULT_REGEX = { + "yapf": + r".*\.py", + "clang-format": + r".*\.(cpp|cc|c\+\+|cxx|c|cl|h|hh|hpp|hxx|m|mm|inc|js|ts|proto|" + r"protodevel|java|cs)", +} + + +def parse_arguments(): + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument( + "binary", + help="Location of binary to use for formatting. This controls the " + "default values of --regex and --lines-style. If binary isn't 'yapf' " + "or 'clang-format' then --regex and --lines-style are required.") + parser.add_argument( + "--regex", + metavar="PATTERN", + default=None, + help="Regex pattern for selecting file paths to reformat from the piped " + "diff. This flag is required if 'binary' is not set to 'yapf' or " + "'clang-format'. Otherwise, this flag overrides the default pattern that " + "--binary sets.") + parser.add_argument( + "--lines-style", + default=None, + help= + "How to style the 'lines' argument for the given binary. Can be set " + "to 'yapf' or 'clang-format'. This flag is required if 'binary' is not " + "set to 'yapf' or 'clang-format'.") + parser.add_argument( + "-p", + metavar="NUM", + default=1, + help="Strip the smallest prefix containing P slashes. Set to 0 if " + "passing `--no-prefix` to `git diff` or using `svn`") + + # Parse and error-check arguments + args, binary_args = parser.parse_known_args() + if args.binary not in BINARY_TO_DEFAULT_REGEX: + if not args.regex: + raise parser.error( + "If 'binary' is not 'yapf' or 'clang-format' then " + "--regex must be set.") + if not args.lines_style: + raise parser.error( + "If 'binary' is not 'yapf' or 'clang-format' then " + "--lines-style must be set.") + else: + # Set defaults based off of 'binary'. 
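+        # e.g. binary='yapf' defaults to the .py regex with yapf-style '--lines start-end' arguments,
+        # while binary='clang-format' defaults to the C/C++ regex with '-lines start:end' arguments.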
+ if not args.regex: + args.regex = BINARY_TO_DEFAULT_REGEX[args.binary] + if not args.lines_style: + args.lines_style = args.binary + + if args.lines_style not in ["yapf", "clang-format"]: + raise parser.error( + f"Unexpected value for --lines-style {args.lines_style}") + + return args, binary_args + + +def main(): + args, binary_args = parse_arguments() + + # Extract changed lines for each file. + filename = None + lines_by_file = {} + for line in sys.stdin: + # Match all filenames. + match = re.search(fr"^\+\+\+\ (.*?/){{{args.p}}}(\S*)", line) + if match: + filename = match.group(2) + if filename is None: + continue + + # Match all filenames specified by --regex. + if not re.match(f"^{args.regex}$", filename): + continue + + # Match unified diff line numbers. + match = re.search(r"^@@.*\+(\d+)(,(\d+))?", line) + if match: + start_line = int(match.group(1)) + line_count = 1 + if match.group(3): + line_count = int(match.group(3)) + if line_count == 0: + continue + end_line = start_line + line_count - 1 + + if args.lines_style == "yapf": + lines = ["--lines", f"{start_line}-{end_line}"] + elif args.lines_style == "clang-format": + lines = ["-lines", f"{start_line}:{end_line}"] + lines_by_file.setdefault(filename, []).extend(lines) + + # Pass the changed lines to 'binary' alongside any unparsed args (e.g. -i). + for filename, lines in lines_by_file.items(): + command = [args.binary, filename] + command.extend(lines) + command.extend(binary_args) + + print(f"Running `{' '.join(command)}`") + p = subprocess.Popen(command, + stdout=subprocess.PIPE, + stderr=None, + stdin=subprocess.PIPE, + universal_newlines=True) + stdout, stderr = p.communicate() + if p.returncode != 0: + sys.exit(p.returncode) + + # If the formatter printed the formatted code to stdout then print out + # a unified diff between the formatted and unformatted code. + # If flags like --verbose are passed to the binary then the diffs this + # produces won't be particularly helpful. + formatted_code = io.StringIO(stdout).readlines() + if len(formatted_code): + with open(filename) as f: + unformatted_code = f.readlines() + diff = difflib.unified_diff(unformatted_code, + formatted_code, + fromfile=filename, + tofile=filename, + fromfiledate="(before formatting)", + tofiledate="(after formatting)") + diff_string = "".join(diff) + if len(diff_string) > 0: + sys.stdout.write(diff_string) + + +if __name__ == "__main__": + main() diff --git a/tools/scripts/lint_check/lint.sh b/tools/scripts/lint_check/lint.sh new file mode 100755 index 0000000..3686eef --- /dev/null +++ b/tools/scripts/lint_check/lint.sh @@ -0,0 +1,49 @@ +#!/bin/bash +# Copyright (C) [2022-2025] by Cambricon. + +# ============================================================================== +# Copyright 2021 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# ============================================================================== + +# Runs all the lint checks that we run in CI locally. + +# WARNING: this script *makes changes* to the working directory and the index. + +set -uo pipefail + +FINAL_RET=0 +LATEST_RET=0 + +SCRIPTS_DIR="$(dirname $0)" +source ${SCRIPTS_DIR}/common.sh +BASE_REF="${1:-master}" +pip install yapf + +apt install clang-format -y + +enable_update_ret + +echo $BASE_REF + +echo "***** yapf *****" +# Don't fail script if condition is false +files=`find . 
-name "*.py" |grep -v "./triton/"` +SKIP_FILE_LIST="./triton/" +SKIP_FILE_LIST=$(echo $SKIP_FILE_LIST | tr ' ' '\n') +for file in $files; do + if echo "$SKIP_FILE_LIST" | grep -Fxq "$file"; then + echo "***** skip $file for yapf *****" + continue + fi + yapf $file -i --lines 1-2000 +done + +echo "***** clang-format *****" +git-clang-format --style=file $BASE_REF +git diff --exit-code -- ./lint_check + +check_ret diff --git a/tools/scripts/test_triton-linalg.sh b/tools/scripts/test_triton-linalg.sh new file mode 100644 index 0000000..9596d61 --- /dev/null +++ b/tools/scripts/test_triton-linalg.sh @@ -0,0 +1,60 @@ +#!/bin/bash + +function check_ret(){ + if [ $1 -ne 0 ]; then exit 1; fi +} + +function test_linalg_unittest() { + mkdir -p ${CI_WORK_DIR}/test_logs + export TRITON_PLUGIN_DIRS=${CI_WORK_DIR}/triton-linalg + pip3 install lit + pushd ${CI_WORK_DIR}/triton-linalg/triton/python/build/cmake.linux-x86_64-cpython-3.10/third_party/triton_linalg + if lit test \ + --xunit-xml-output ${CI_WORK_DIR}/test_logs/test_linalg_unittest_results.xml; + then + error=0 + else + error=1 + fi + popd + check_ret ${error} +} + +function print_usage() { + RED='\033[0;31m' + BLUE='\033[0;34m' + BOLD='\033[1m' + NONE='\033[0m' + + echo -e "${BOLD}bash ci_daily.sh${NONE} Command [Options]" + + echo -e "\n${RED}Command${NONE}: + + ${BLUE}test_linalg_unittest${NONE}: Run tests + + ${BLUE}usage${NONE}: display this message + " +} + +# main entry +function main() { + local cmd=$1 + N_JOBS="" + if [ ! -z "$2" ]; then + N_JOBS=$2 + fi + case $cmd in + test_linalg_unittest) + test_linalg_unittest + ;; + usage) + print_usage + ;; + *) + print_usage + ;; + esac +} + +main $@ +
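+
+# Example invocation (illustrative values; CI_WORK_DIR is normally exported
+# by the CI pipeline before this script runs):
+#   export CI_WORK_DIR=/home/user1/ci_workspace   # hypothetical workspace path
+#   bash tools/scripts/test_triton-linalg.sh test_linalg_unittest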