Autotuner from hell
mikex86 committed Sep 10, 2024
1 parent c8f2d17 commit 4e718ee
Showing 15 changed files with 1,825 additions and 614 deletions.
28 changes: 25 additions & 3 deletions CMakeLists.txt
@@ -2,6 +2,13 @@ cmake_minimum_required(VERSION 3.28)

project(tritonc CXX)

# Enable exceptions
if (CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang")
add_compile_options(-fexceptions)
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
add_compile_options(/EHsc)
endif ()

set(CMAKE_CXX_STANDARD 17)

set(TRITON_CODEGEN_BACKENDS "nvidia" CACHE STRING "Enable different codegen backends" FORCE)
@@ -12,11 +19,26 @@ set(PYTHON_INCLUDE_DIRS ON)
add_subdirectory(third_party/argparse)
add_subdirectory(third_party/triton)

add_executable(tritonc src/main.cpp)
add_executable(tritonc src/main.cpp
src/tritonc.cpp
src/autotuner.cpp
src/kernel_runner.cpp
src/tinyscript.cpp
src/templating.cpp)
target_link_libraries(tritonc PRIVATE argparse)
target_link_libraries(tritonc PRIVATE cuda curand)

get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS)

# We need to hackishly turn an assert into an exception because it is violated very often...
# We therefore override assert() for the TritonTools target:
# we pre-define the include guard _ASSERT_H_DECLS so that assert.h never declares __assert_fail,
# and force-include a header that defines a macro redirecting __assert_fail to a throw.
# (yikes)
target_compile_definitions(TritonTools PRIVATE "_ASSERT_H_DECLS=1")
target_compile_options(TritonTools PRIVATE -include "${CMAKE_SOURCE_DIR}/include_hacks/fake_except_assert.h")
target_compile_options(TritonTools PRIVATE -Wno-terminate) # grr

set(TRITON_LIBRARIES
${triton_libs}

@@ -70,9 +92,9 @@ elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "ppc64le")
LLVMPowerPCAsmParser
LLVMPowerPCCodeGen
)
else()
else ()
message(FATAL_ERROR "LLVM codegen/ASM parser libs: This HW architecture (${CMAKE_SYSTEM_PROCESSOR}) is not configured in cmake lib dependencies.")
endif()
endif ()

target_link_libraries(tritonc PRIVATE ${TRITON_LIBRARIES})

11 changes: 11 additions & 0 deletions include_hacks/fake_except_assert.h
@@ -0,0 +1,11 @@
// fake_except_assert.h
#pragma once

#include <stdexcept> // std::runtime_error must be visible wherever assert() expands

// redirect glibc's __assert_fail to a C++ exception instead of aborting
#define __assert_fail(expr, file, line, function) throw std::runtime_error("Assertion failed")

// we need this assert in release mode as well...
#define assert(expr) \
    (static_cast <bool> (expr) \
        ? void (0) \
        : __assert_fail (#expr, __ASSERT_FILE, __ASSERT_LINE, \
                         __ASSERT_FUNCTION))
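
With this header force-included (and exceptions enabled via -fexceptions / /EHsc above), an assertion failure inside the Triton libraries surfaces as a std::runtime_error the caller can catch, rather than aborting the whole autotuning run. Below is a minimal sketch of the intended effect, assuming the autotuner wraps each candidate compilation roughly like this; tryCandidate and compileCandidate are illustrative names, not the actual API in this commit.

#include <functional>
#include <iostream>
#include <stdexcept>

// Hypothetical stand-in for compiling one autotune configuration through Triton.
// With fake_except_assert.h force-included into TritonTools, an assert() firing
// inside compileCandidate() throws std::runtime_error instead of calling abort().
bool tryCandidate(const std::function<void()> &compileCandidate) {
    try {
        compileCandidate();
        return true;   // configuration compiled fine
    } catch (const std::runtime_error &e) {
        std::cerr << "skipping config: " << e.what() << '\n';
        return false;  // assertion tripped inside Triton; skip this configuration
    }
}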
148 changes: 148 additions & 0 deletions matmul_autotune.ttir
@@ -0,0 +1,148 @@
#pragma autotune BLOCK_SIZE_M {64, 128, 256}
#pragma autotune BLOCK_SIZE_N {64, 128, 256}
#pragma autotune BLOCK_SIZE_K {32, 64}
#pragma autotune GROUP_SIZE_M {12, 16, 20, 24}
#pragma autotune intrinsic num_warps {2, 4, 8, 16}
#pragma autotune intrinsic num_stages {3, 4, 5}

#pragma argument 0 ptr cuMalloc(8192 * 8192 * 2)
#pragma argument 1 ptr cuMalloc(8192 * 8192 * 2)
#pragma argument 2 ptr cuMalloc(8192 * 8192 * 2)

// %arg3 = M
#pragma argument 3 i32 8192
// %arg4 = N
#pragma argument 4 i32 8192
// %arg5 = K
#pragma argument 5 i32 8192

// %arg6 = stride_am
#pragma argument 6 i32 8192
// %arg7 = stride_bk
#pragma argument 7 i32 8192
// %arg8 = stride_cm
#pragma argument 8 i32 8192

#pragma grid x ((8192 / ${BLOCK_SIZE_M}) * (8192 / ${BLOCK_SIZE_N}))

#pragma launch matmul_kernel

module {
tt.func public @matmul_kernel(
%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32},
%arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32},
%arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32},
%arg3: i32 {tt.divisibility = 16 : i32},
%arg4: i32 {tt.divisibility = 16 : i32},
%arg5: i32 {tt.divisibility = 16 : i32},
%arg6: i32 {tt.divisibility = 16 : i32},
%arg7: i32 {tt.divisibility = 16 : i32},
%arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%cst = arith.constant dense<0> : tensor<1x${BLOCK_SIZE_N}xi64>
%cst_0 = arith.constant dense<0> : tensor<${BLOCK_SIZE_M}x1xi64>
%c32_i64 = arith.constant ${BLOCK_SIZE_K} : i64
%c0_i64 = arith.constant 0 : i64
%c255_i32 = arith.constant ${${BLOCK_SIZE_N} - 1} : i32
%c127_i32 = arith.constant ${${BLOCK_SIZE_M} - 1} : i32
%cst_1 = arith.constant dense<0.000000e+00> : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf32>
%c32_i32 = arith.constant ${BLOCK_SIZE_K} : i32
%c256_i32 = arith.constant ${BLOCK_SIZE_N} : i32
%c0_i32 = arith.constant 0 : i32
%c128_i32 = arith.constant ${BLOCK_SIZE_M} : i32
%c24_i32 = arith.constant ${GROUP_SIZE_M} : i32
%0 = tt.get_program_id x : i32
%1 = arith.addi %arg3, %c127_i32 : i32
%2 = arith.divsi %1, %c128_i32 : i32
%3 = arith.addi %arg4, %c255_i32 : i32
%4 = arith.divsi %3, %c256_i32 : i32
%5 = arith.muli %4, %c24_i32 : i32
%6 = arith.divsi %0, %5 : i32
%7 = arith.muli %6, %c24_i32 : i32
%8 = arith.subi %2, %7 : i32
%9 = arith.minsi %8, %c24_i32 : i32
%10 = arith.remsi %0, %9 : i32
%11 = arith.addi %7, %10 : i32
%12 = arith.remsi %0, %5 : i32
%13 = arith.divsi %12, %9 : i32
%14 = arith.muli %11, %c128_i32 : i32
%15 = arith.extsi %arg6 : i32 to i64
%16 = arith.extsi %14 : i32 to i64
%17 = arith.muli %13, %c256_i32 : i32
%18 = arith.extsi %arg7 : i32 to i64
%19 = arith.extsi %17 : i32 to i64
%20 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_K}x!tt.ptr<f16>>
%21 = tt.splat %16 : i64 -> tensor<${BLOCK_SIZE_M}xi64>
%22 = tt.make_range {end = ${BLOCK_SIZE_M} : i32, start = 0 : i32} : tensor<${BLOCK_SIZE_M}xi32>
%23 = arith.extsi %22 : tensor<${BLOCK_SIZE_M}xi32> to tensor<${BLOCK_SIZE_M}xi64>
%24 = arith.addi %21, %23 : tensor<${BLOCK_SIZE_M}xi64>
%25 = tt.expand_dims %24 {axis = 1 : i32} : tensor<${BLOCK_SIZE_M}xi64> -> tensor<${BLOCK_SIZE_M}x1xi64>
%26 = tt.splat %15 : i64 -> tensor<${BLOCK_SIZE_M}x1xi64>
%27 = arith.muli %25, %26 : tensor<${BLOCK_SIZE_M}x1xi64>
%28 = tt.broadcast %27 : tensor<${BLOCK_SIZE_M}x1xi64> -> tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_K}xi64>
%29 = tt.make_range {end = ${BLOCK_SIZE_K} : i32, start = 0 : i32} : tensor<${BLOCK_SIZE_K}xi32>
%30 = arith.extsi %29 : tensor<${BLOCK_SIZE_K}xi32> to tensor<${BLOCK_SIZE_K}xi64>
%31 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<${BLOCK_SIZE_K}x${BLOCK_SIZE_N}x!tt.ptr<f16>>
%32 = tt.splat %18 : i64 -> tensor<${BLOCK_SIZE_K}x1xi64>
%33 = tt.splat %19 : i64 -> tensor<${BLOCK_SIZE_N}xi64>
%34 = tt.make_range {end = ${BLOCK_SIZE_N} : i32, start = 0 : i32} : tensor<${BLOCK_SIZE_N}xi32>
%35 = arith.extsi %34 : tensor<${BLOCK_SIZE_N}xi32> to tensor<${BLOCK_SIZE_N}xi64>
%36 = arith.addi %33, %35 : tensor<${BLOCK_SIZE_N}xi64>
%37 = tt.expand_dims %36 {axis = 0 : i32} : tensor<${BLOCK_SIZE_N}xi64> -> tensor<1x${BLOCK_SIZE_N}xi64>
%38 = tt.broadcast %37 : tensor<1x${BLOCK_SIZE_N}xi64> -> tensor<${BLOCK_SIZE_K}x${BLOCK_SIZE_N}xi64>
%39:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst_1, %arg11 = %c0_i64, %arg12 = %c0_i64) -> (tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf32>, i64, i64) : i32 {
%72 = tt.splat %arg11 : i64 -> tensor<${BLOCK_SIZE_K}xi64>
%73 = arith.addi %72, %30 : tensor<${BLOCK_SIZE_K}xi64>
%74 = tt.expand_dims %73 {axis = 0 : i32} : tensor<${BLOCK_SIZE_K}xi64> -> tensor<1x${BLOCK_SIZE_K}xi64>
%75 = tt.broadcast %74 : tensor<1x${BLOCK_SIZE_K}xi64> -> tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_K}xi64>
%76 = arith.addi %28, %75 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_K}xi64>
%77 = tt.addptr %20, %76 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_K}x!tt.ptr<f16>>, tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_K}xi64>
%78 = tt.load %77 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_K}x!tt.ptr<f16>>
%79 = tt.splat %arg12 : i64 -> tensor<${BLOCK_SIZE_K}xi64>
%80 = arith.addi %79, %30 : tensor<${BLOCK_SIZE_K}xi64>
%81 = tt.expand_dims %80 {axis = 1 : i32} : tensor<${BLOCK_SIZE_K}xi64> -> tensor<${BLOCK_SIZE_K}x1xi64>
%82 = arith.muli %81, %32 : tensor<${BLOCK_SIZE_K}x1xi64>
%83 = tt.broadcast %82 : tensor<${BLOCK_SIZE_K}x1xi64> -> tensor<${BLOCK_SIZE_K}x${BLOCK_SIZE_N}xi64>
%84 = arith.addi %83, %38 : tensor<${BLOCK_SIZE_K}x${BLOCK_SIZE_N}xi64>
%85 = tt.addptr %31, %84 : tensor<${BLOCK_SIZE_K}x${BLOCK_SIZE_N}x!tt.ptr<f16>>, tensor<${BLOCK_SIZE_K}x${BLOCK_SIZE_N}xi64>
%86 = tt.load %85 : tensor<${BLOCK_SIZE_K}x${BLOCK_SIZE_N}x!tt.ptr<f16>>
%87 = tt.dot %78, %86, %arg10, inputPrecision = tf32 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_K}xf16> * tensor<${BLOCK_SIZE_K}x${BLOCK_SIZE_N}xf16> -> tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf32>
%88 = arith.addi %arg11, %c32_i64 : i64
%89 = arith.addi %arg12, %c32_i64 : i64
scf.yield %87, %88, %89 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf32>, i64, i64
}
%40 = arith.truncf %39#0 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf32> to tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xf16>
%41 = arith.extsi %arg3 : i32 to i64
%42 = arith.extsi %arg4 : i32 to i64
%43 = arith.extsi %arg8 : i32 to i64
%44 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}x!tt.ptr<f16>>
%45 = tt.splat %16 : i64 -> tensor<${BLOCK_SIZE_M}xi64>
%46 = tt.make_range {end = ${BLOCK_SIZE_M} : i32, start = 0 : i32} : tensor<${BLOCK_SIZE_M}xi32>
%47 = arith.extsi %46 : tensor<${BLOCK_SIZE_M}xi32> to tensor<${BLOCK_SIZE_M}xi64>
%48 = arith.addi %45, %47 : tensor<${BLOCK_SIZE_M}xi64>
%49 = tt.expand_dims %48 {axis = 1 : i32} : tensor<${BLOCK_SIZE_M}xi64> -> tensor<${BLOCK_SIZE_M}x1xi64>
%50 = tt.splat %43 : i64 -> tensor<${BLOCK_SIZE_M}x1xi64>
%51 = arith.muli %49, %50 : tensor<${BLOCK_SIZE_M}x1xi64>
%52 = tt.broadcast %51 : tensor<${BLOCK_SIZE_M}x1xi64> -> tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xi64>
%53 = tt.splat %19 : i64 -> tensor<${BLOCK_SIZE_N}xi64>
%54 = tt.make_range {end = ${BLOCK_SIZE_N} : i32, start = 0 : i32} : tensor<${BLOCK_SIZE_N}xi32>
%55 = arith.extsi %54 : tensor<${BLOCK_SIZE_N}xi32> to tensor<${BLOCK_SIZE_N}xi64>
%56 = arith.addi %53, %55 : tensor<${BLOCK_SIZE_N}xi64>
%57 = tt.expand_dims %56 {axis = 0 : i32} : tensor<${BLOCK_SIZE_N}xi64> -> tensor<1x${BLOCK_SIZE_N}xi64>
%58 = tt.broadcast %57 : tensor<1x${BLOCK_SIZE_N}xi64> -> tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xi64>
%59 = arith.addi %52, %58 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xi64>
%60 = tt.addptr %44, %59 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}x!tt.ptr<f16>>, tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xi64>
%61 = arith.cmpi sge, %49, %cst_0 : tensor<${BLOCK_SIZE_M}x1xi64>
%62 = tt.splat %41 : i64 -> tensor<${BLOCK_SIZE_M}x1xi64>
%63 = arith.cmpi slt, %49, %62 : tensor<${BLOCK_SIZE_M}x1xi64>
%64 = arith.andi %61, %63 : tensor<${BLOCK_SIZE_M}x1xi1>
%65 = tt.broadcast %64 : tensor<${BLOCK_SIZE_M}x1xi1> -> tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xi1>
%66 = arith.cmpi sge, %57, %cst : tensor<1x${BLOCK_SIZE_N}xi64>
%67 = tt.splat %42 : i64 -> tensor<1x${BLOCK_SIZE_N}xi64>
%68 = arith.cmpi slt, %57, %67 : tensor<1x${BLOCK_SIZE_N}xi64>
%69 = arith.andi %66, %68 : tensor<1x${BLOCK_SIZE_N}xi1>
%70 = tt.broadcast %69 : tensor<1x${BLOCK_SIZE_N}xi1> -> tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xi1>
%71 = arith.andi %65, %70 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}xi1>
tt.store %60, %40, %71 : tensor<${BLOCK_SIZE_M}x${BLOCK_SIZE_N}x!tt.ptr<f16>>
tt.return
}
}
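
The pragmas at the top of this file drive the autotuner: each #pragma autotune line declares a tunable parameter and its candidate values (the "intrinsic" ones correspond to Triton's num_warps / num_stages compile options), #pragma argument describes how each kernel argument is materialized for benchmarking, #pragma grid x gives the launch-grid expression, and #pragma launch names the kernel entry point. For example, with BLOCK_SIZE_M = 128 and BLOCK_SIZE_N = 256 the grid expression evaluates to (8192 / 128) * (8192 / 256) = 64 * 32 = 2048 blocks. The sketch below shows one plausible shape of the expansion step, assuming the autotuner enumerates the Cartesian product of all candidate values and substitutes each ${NAME} placeholder into the IR text; Config, enumerateConfigs and expandTemplate are illustrative names, not the actual interfaces of src/autotuner.cpp / src/templating.cpp.

#include <map>
#include <string>
#include <vector>

using Config = std::map<std::string, int>;

// Enumerate the Cartesian product of all candidate values, e.g.
// {BLOCK_SIZE_M: {64,128,256}, BLOCK_SIZE_N: {64,128,256}, ...}.
std::vector<Config> enumerateConfigs(const std::map<std::string, std::vector<int>> &space) {
    std::vector<Config> configs{{}};
    for (const auto &[name, values] : space) {
        std::vector<Config> next;
        for (const auto &partial : configs) {
            for (int v : values) {
                Config c = partial;
                c[name] = v;
                next.push_back(std::move(c));
            }
        }
        configs = std::move(next);
    }
    return configs;
}

// Replace every ${NAME} occurrence in the templated IR with the chosen value.
// (The real templating also has to evaluate arithmetic such as ${${BLOCK_SIZE_N} - 1};
// that part is omitted here.)
std::string expandTemplate(std::string ir, const Config &config) {
    for (const auto &[name, value] : config) {
        const std::string needle = "${" + name + "}";
        const std::string replacement = std::to_string(value);
        for (size_t pos = ir.find(needle); pos != std::string::npos; pos = ir.find(needle, pos)) {
            ir.replace(pos, needle.size(), replacement);
            pos += replacement.size();
        }
    }
    return ir;
}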
146 changes: 146 additions & 0 deletions matmul_autotune_bak.ttir
@@ -0,0 +1,146 @@
#pragma autotune BLOCK_SIZE_M {64, 128, 256}
#pragma autotune BLOCK_SIZE_N {64, 128, 256}
#pragma autotune BLOCK_SIZE_K {32, 64}
#pragma autotune GROUP_SIZE_M {8, 12, 16, 20, 24}
#pragma autotune intrinsic num_warps {2, 4, 8}
#pragma autotune intrinsic num_stages {3, 4, 5}

#pragma argument 0 ptr cuMalloc(8192 * 8192 * 2)
#pragma argument 1 ptr cuMalloc(8192 * 8192 * 2)
#pragma argument 2 ptr cuMalloc(8192 * 8192 * 2)

// %arg3 = M
#pragma argument 3 i32 8192
// %arg4 = N
#pragma argument 4 i32 8192
// %arg5 = K
#pragma argument 5 i32 8192

// %arg6 = stride_am
#pragma argument 6 i32 8192
// %arg7 = stride_bk
#pragma argument 7 i32 8192
// %arg8 = stride_cm
#pragma argument 8 i32 8192

#pragma gridX (8192 / ${BLOCK_SIZE_M}) * (8192 / ${BLOCK_SIZE_N})

module {
tt.func public @matmul_kernel(
%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32},
%arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32},
%arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32},
%arg3: i32 {tt.divisibility = 16 : i32},
%arg4: i32 {tt.divisibility = 16 : i32},
%arg5: i32 {tt.divisibility = 16 : i32},
%arg6: i32 {tt.divisibility = 16 : i32},
%arg7: i32 {tt.divisibility = 16 : i32},
%arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%cst = arith.constant dense<0> : tensor<1x256xi64>
%cst_0 = arith.constant dense<0> : tensor<128x1xi64>
%c32_i64 = arith.constant 32 : i64
%c0_i64 = arith.constant 0 : i64
%c255_i32 = arith.constant 255 : i32
%c127_i32 = arith.constant 127 : i32
%cst_1 = arith.constant dense<0.000000e+00> : tensor<128x256xf32>
%c32_i32 = arith.constant 32 : i32
%c256_i32 = arith.constant 256 : i32
%c0_i32 = arith.constant 0 : i32
%c128_i32 = arith.constant 128 : i32
%c${BLOCK_SIZE_K}_i32 = arith.constant ${BLOCK_SIZE_K} : i32
%0 = tt.get_program_id x : i32
%1 = arith.addi %arg3, %c127_i32 : i32
%2 = arith.divsi %1, %c128_i32 : i32
%3 = arith.addi %arg4, %c255_i32 : i32
%4 = arith.divsi %3, %c256_i32 : i32
%5 = arith.muli %4, %c${BLOCK_SIZE_K}_i32 : i32
%6 = arith.divsi %0, %5 : i32
%7 = arith.muli %6, %c${BLOCK_SIZE_K}_i32 : i32
%8 = arith.subi %2, %7 : i32
%9 = arith.minsi %8, %c${BLOCK_SIZE_K}_i32 : i32
%10 = arith.remsi %0, %9 : i32
%11 = arith.addi %7, %10 : i32
%12 = arith.remsi %0, %5 : i32
%13 = arith.divsi %12, %9 : i32
%14 = arith.muli %11, %c128_i32 : i32
%15 = arith.extsi %arg6 : i32 to i64
%16 = arith.extsi %14 : i32 to i64
%17 = arith.muli %13, %c256_i32 : i32
%18 = arith.extsi %arg7 : i32 to i64
%19 = arith.extsi %17 : i32 to i64
%20 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x32x!tt.ptr<f16>>
%21 = tt.splat %16 : i64 -> tensor<128xi64>
%22 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
%23 = arith.extsi %22 : tensor<128xi32> to tensor<128xi64>
%24 = arith.addi %21, %23 : tensor<128xi64>
%25 = tt.expand_dims %24 {axis = 1 : i32} : tensor<128xi64> -> tensor<128x1xi64>
%26 = tt.splat %15 : i64 -> tensor<128x1xi64>
%27 = arith.muli %25, %26 : tensor<128x1xi64>
%28 = tt.broadcast %27 : tensor<128x1xi64> -> tensor<128x32xi64>
%29 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
%30 = arith.extsi %29 : tensor<32xi32> to tensor<32xi64>
%31 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<32x256x!tt.ptr<f16>>
%32 = tt.splat %18 : i64 -> tensor<32x1xi64>
%33 = tt.splat %19 : i64 -> tensor<256xi64>
%34 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
%35 = arith.extsi %34 : tensor<256xi32> to tensor<256xi64>
%36 = arith.addi %33, %35 : tensor<256xi64>
%37 = tt.expand_dims %36 {axis = 0 : i32} : tensor<256xi64> -> tensor<1x256xi64>
%38 = tt.broadcast %37 : tensor<1x256xi64> -> tensor<32x256xi64>
%39:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst_1, %arg11 = %c0_i64, %arg12 = %c0_i64) -> (tensor<128x256xf32>, i64, i64) : i32 {
%72 = tt.splat %arg11 : i64 -> tensor<32xi64>
%73 = arith.addi %72, %30 : tensor<32xi64>
%74 = tt.expand_dims %73 {axis = 0 : i32} : tensor<32xi64> -> tensor<1x32xi64>
%75 = tt.broadcast %74 : tensor<1x32xi64> -> tensor<128x32xi64>
%76 = arith.addi %28, %75 : tensor<128x32xi64>
%77 = tt.addptr %20, %76 : tensor<128x32x!tt.ptr<f16>>, tensor<128x32xi64>
%78 = tt.load %77 : tensor<128x32x!tt.ptr<f16>>
%79 = tt.splat %arg12 : i64 -> tensor<32xi64>
%80 = arith.addi %79, %30 : tensor<32xi64>
%81 = tt.expand_dims %80 {axis = 1 : i32} : tensor<32xi64> -> tensor<32x1xi64>
%82 = arith.muli %81, %32 : tensor<32x1xi64>
%83 = tt.broadcast %82 : tensor<32x1xi64> -> tensor<32x256xi64>
%84 = arith.addi %83, %38 : tensor<32x256xi64>
%85 = tt.addptr %31, %84 : tensor<32x256x!tt.ptr<f16>>, tensor<32x256xi64>
%86 = tt.load %85 : tensor<32x256x!tt.ptr<f16>>
%87 = tt.dot %78, %86, %arg10, inputPrecision = tf32 : tensor<128x32xf16> * tensor<32x256xf16> -> tensor<128x256xf32>
%88 = arith.addi %arg11, %c32_i64 : i64
%89 = arith.addi %arg12, %c32_i64 : i64
scf.yield %87, %88, %89 : tensor<128x256xf32>, i64, i64
}
%40 = arith.truncf %39#0 : tensor<128x256xf32> to tensor<128x256xf16>
%41 = arith.extsi %arg3 : i32 to i64
%42 = arith.extsi %arg4 : i32 to i64
%43 = arith.extsi %arg8 : i32 to i64
%44 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<128x256x!tt.ptr<f16>>
%45 = tt.splat %16 : i64 -> tensor<128xi64>
%46 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
%47 = arith.extsi %46 : tensor<128xi32> to tensor<128xi64>
%48 = arith.addi %45, %47 : tensor<128xi64>
%49 = tt.expand_dims %48 {axis = 1 : i32} : tensor<128xi64> -> tensor<128x1xi64>
%50 = tt.splat %43 : i64 -> tensor<128x1xi64>
%51 = arith.muli %49, %50 : tensor<128x1xi64>
%52 = tt.broadcast %51 : tensor<128x1xi64> -> tensor<128x256xi64>
%53 = tt.splat %19 : i64 -> tensor<256xi64>
%54 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
%55 = arith.extsi %54 : tensor<256xi32> to tensor<256xi64>
%56 = arith.addi %53, %55 : tensor<256xi64>
%57 = tt.expand_dims %56 {axis = 0 : i32} : tensor<256xi64> -> tensor<1x256xi64>
%58 = tt.broadcast %57 : tensor<1x256xi64> -> tensor<128x256xi64>
%59 = arith.addi %52, %58 : tensor<128x256xi64>
%60 = tt.addptr %44, %59 : tensor<128x256x!tt.ptr<f16>>, tensor<128x256xi64>
%61 = arith.cmpi sge, %49, %cst_0 : tensor<128x1xi64>
%62 = tt.splat %41 : i64 -> tensor<128x1xi64>
%63 = arith.cmpi slt, %49, %62 : tensor<128x1xi64>
%64 = arith.andi %61, %63 : tensor<128x1xi1>
%65 = tt.broadcast %64 : tensor<128x1xi1> -> tensor<128x256xi1>
%66 = arith.cmpi sge, %57, %cst : tensor<1x256xi64>
%67 = tt.splat %42 : i64 -> tensor<1x256xi64>
%68 = arith.cmpi slt, %57, %67 : tensor<1x256xi64>
%69 = arith.andi %66, %68 : tensor<1x256xi1>
%70 = tt.broadcast %69 : tensor<1x256xi1> -> tensor<128x256xi1>
%71 = arith.andi %65, %70 : tensor<128x256xi1>
tt.store %60, %40, %71 : tensor<128x256x!tt.ptr<f16>>
tt.return
}
}
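
Running one candidate presumably follows the #pragma argument directives: allocate the three pointer arguments on the device, pass the i32 shape and stride arguments, and launch the compiled kernel with the grid size produced by the grid expression. The sketch below uses the CUDA driver API, which CMakeLists.txt links against; it is illustrative and not necessarily how src/kernel_runner.cpp is structured (launchCandidate is a made-up name, and error checking is omitted).

#include <cuda.h>

// Launch one compiled candidate. a/b/c correspond to the cuMalloc'd pointer
// arguments, the i32 arguments follow, and gridX comes from the #pragma grid
// expression, e.g. (8192 / BLOCK_SIZE_M) * (8192 / BLOCK_SIZE_N).
void launchCandidate(CUfunction kernel, unsigned gridX, unsigned numWarps, unsigned sharedMemBytes) {
    CUdeviceptr a, b, c;
    const size_t bytes = 8192ull * 8192ull * 2;   // matches cuMalloc(8192 * 8192 * 2)
    cuMemAlloc(&a, bytes);
    cuMemAlloc(&b, bytes);
    cuMemAlloc(&c, bytes);

    int M = 8192, N = 8192, K = 8192;
    int strideAm = 8192, strideBk = 8192, strideCm = 8192;
    void *args[] = {&a, &b, &c, &M, &N, &K, &strideAm, &strideBk, &strideCm};

    // Triton launches 32 * num_warps threads per block; sharedMemBytes would come
    // from the compiled kernel's metadata.
    cuLaunchKernel(kernel,
                   gridX, 1, 1,            // grid
                   32 * numWarps, 1, 1,    // block
                   sharedMemBytes, nullptr, args, nullptr);
    cuCtxSynchronize();

    cuMemFree(a);
    cuMemFree(b);
    cuMemFree(c);
}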