From cc25374fa480c0b3e51cf218ed6fe7eb4c50a5bb Mon Sep 17 00:00:00 2001 From: SJW <48454132+sjw36@users.noreply.github.com> Date: Mon, 11 Nov 2024 18:35:02 -0600 Subject: [PATCH] [AMD][Pipeliner] Improve clustering and add prefetch (#4881) This commit improves pipeliner op clustering so that we can avoid relying complicated and fragile reordering step later. In order to do this, we formalized stages a bit and improved documentation accordingly. Also this commit adds an extra experimental stage to buffer in registers before compute, which is a part of a series of commits to improve scheduling perf. --- .../amd/amd-reorder-instructions.mlir | 345 ------------- test/TritonGPU/amd/amd-sched-2nd-load.mlir | 32 +- test/TritonGPU/loop-pipeline-hip.mlir | 4 +- test/TritonGPU/loop-pipeline.mlir | 475 +++++++++++------- third_party/amd/backend/compiler.py | 3 +- .../include/TritonAMDGPUTransforms/Passes.h | 3 +- .../include/TritonAMDGPUTransforms/Passes.td | 5 +- .../ReorderInstructions.cpp | 6 +- .../StreamPipelineV2.cpp | 317 ++++++++---- third_party/amd/python/triton_amd.cc | 4 +- 10 files changed, 548 insertions(+), 646 deletions(-) diff --git a/test/TritonGPU/amd/amd-reorder-instructions.mlir b/test/TritonGPU/amd/amd-reorder-instructions.mlir index 5dfd0f2a5f4c..d7be023312ea 100644 --- a/test/TritonGPU/amd/amd-reorder-instructions.mlir +++ b/test/TritonGPU/amd/amd-reorder-instructions.mlir @@ -115,351 +115,6 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war } } -// ----- -// Move loads (and independent local_stores) as early as possible. -// For example in the matmul_loop below, the scf.for loop looks like this after pipeliner: -// scf.for ... { -// // stage 1 -// %a = tt.local_load %a_tile -// %b = tt.local_load %b_tile -// tt.dot %c, %a, %b -// // stage 0 -// %aptr = tt.addptr %aptr, %k -// %a_next = tt.load %aptr -// %bptr = tt.addptr %bptr, %k -// %b_next = tt.load %bptr -// tt.local_store %a_next -// tt.local_store %b_next -// yield -// } -// -// Solution for num_stages=2 : -// scf.for ... { -// // stage 0.a -// %aptr = tt.addptr %aptr, %k -// %a_next = tt.load %aptr -// %bptr = tt.addptr %bptr, %k -// %b_next = tt.load %bptr -// // stage 1 -// %a = tt.local_load %a_tile -// %b = tt.local_load %b_tile -// tt.dot %c, %a, %b -// // stage 0.b -// tt.local_store %a_next -// tt.local_store %b_next -// yield -// } -// -// Solution for num_stages=3 (double-buffered) : -// scf.for ... 
{ -// // stage 1 -// tt.local_store %a_next_1 -// tt.local_store %b_next_1 -// // stage 0 -// %aptr = tt.addptr %aptr, %k -// %a_next_2 = tt.load %aptr -// %bptr = tt.addptr %bptr, %k -// %b_next_2 = tt.load %bptr -// // stage 2 -// %a = tt.local_load %a_tile -// %b = tt.local_load %b_tile -// tt.dot %c, %a, %b -// yield -// } - -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = []}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> -#shared2 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], hasLeadingOffset = false}> -#shared3 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> -#shared4 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32, triton_gpu.target = "hip:gfx942"} { - -// CHECK-LABEL: tt.func @matmul_loop -// CHECK: %{{.*}}:6 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}) -// Stage 0.a -// CHECK: %[[ADDPTR_20:.*]] = tt.addptr %[[ARG6]], %{{.*}} -// CHECK: %[[SUBI_21:.*]] = arith.subi %{{.*}}, %{{.*}} -// CHECK: %[[CMPI_22:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_21]] -// CHECK: %[[SPLAT_23:.*]] = tt.splat %[[CMPI_22]] -// CHECK: %[[LOAD_24:.*]] = tt.load %[[ADDPTR_20]], %[[SPLAT_23]] -// CHECK: %[[ADDPTR_25:.*]] = tt.addptr %[[ARG7]], %{{.*}} -// CHECK: %[[SPLAT_26:.*]] = tt.splat %[[CMPI_22]] -// CHECK: %[[LOAD_27:.*]] = tt.load %[[ADDPTR_25]], %[[SPLAT_26]] -// Stage 1 -// CHECK: %[[LOCAL_LOAD_28:.*]] = triton_gpu.local_load %[[ARG10]] -// CHECK: %[[LOCAL_LOAD_29:.*]] = triton_gpu.local_load %[[ARG11]] -// CHECK: %[[MULF_30:.*]] = arith.mulf %[[LOCAL_LOAD_29]], %{{.*}} -// CHECK: %[[DOT_31:.*]] = tt.dot %[[LOCAL_LOAD_28]], %[[MULF_30]], %[[ARG8]] -// Stage 0.b -// CHECK: %[[ADDI_32:.*]] = arith.addi %[[ARG9]], %{{.*}} -// CHECK: %[[CMPI_33:.*]] = arith.cmpi slt, %[[ADDI_32]], %{{.*}} -// CHECK: %[[SELECT_34:.*]] = arith.select %[[CMPI_33]], %[[ADDI_32]], %{{.*}} -// CHECK: %[[MEMDESC_SUBVIEW_35:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_34]], %{{.*}}, %{{.*}}] -// CHECK: triton_gpu.local_store %[[LOAD_24]], %[[MEMDESC_SUBVIEW_35]] -// CHECK: %[[MEMDESC_SUBVIEW_36:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_34]], %{{.*}}, %{{.*}}] -// CHECK: triton_gpu.local_store %[[LOAD_27]], %[[MEMDESC_SUBVIEW_36]] -// CHECK: scf.yield %[[ADDPTR_20]], %[[ADDPTR_25]], %[[DOT_31]], %[[SELECT_34]], %[[MEMDESC_SUBVIEW_35]], %[[MEMDESC_SUBVIEW_36]] -// CHECK: } - - tt.func @matmul_loop(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #mma> { - %c1_i32 = arith.constant 1 : i32 - %c0_i32 = arith.constant 0 : i32 - %cst = arith.constant dense<4.000000e+00> : tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, 
parent = #mma, kWidth = 2}>> - %cst_0 = arith.constant dense<4> : tensor<32x128xi32, #blocked> - %cst_1 = arith.constant dense<4> : tensor<128x32xi32, #blocked1> - %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> - %cst_3 = arith.constant dense<0.000000e+00> : tensor<32x128xf16, #blocked> - %0 = tt.splat %arg3 : !tt.ptr -> tensor<128x32x!tt.ptr, #blocked1> - %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> - %3 = tt.broadcast %2 : tensor<1x32xi32, #blocked1> -> tensor<128x32xi32, #blocked1> - %4 = tt.addptr %0, %3 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> - %5 = tt.splat %arg4 : !tt.ptr -> tensor<32x128x!tt.ptr, #blocked> - %6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> - %8 = tt.broadcast %7 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> - %9 = tt.addptr %5, %8 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> - %10 = triton_gpu.local_alloc : () -> !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - %11 = triton_gpu.local_alloc : () -> !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> - %12 = arith.cmpi slt, %arg0, %arg1 : index - %13 = tt.splat %12 : i1 -> tensor<128x32xi1, #blocked1> - %14 = tt.load %4, %13 : tensor<128x32x!tt.ptr, #blocked1> - %15 = tt.splat %12 : i1 -> tensor<32x128xi1, #blocked> - %16 = tt.load %9, %15, %cst_3 : tensor<32x128x!tt.ptr, #blocked> - %17 = triton_gpu.memdesc_subview %10[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %14, %17 : tensor<128x32xf16, #blocked1> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - %18 = triton_gpu.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %16, %18 : tensor<32x128xf16, #blocked> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> - %19:6 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %4, %arg7 = %9, %arg8 = %cst_2, %arg9 = %c0_i32, %arg10 = %17, %arg11 = %18) -> (tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>) { - %20 = arith.subi %arg1, %arg2 : index - %21 = arith.cmpi slt, %arg5, %20 : index - %22 = triton_gpu.local_load %arg10 : !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %23 = triton_gpu.local_load %arg11 : !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %24 = arith.mulf %23, %cst : tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %25 = tt.dot %22, %24, %arg8 : tensor<128x32xf16, 
#triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> - %26 = tt.addptr %arg6, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> - %27 = tt.addptr %arg7, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> - %28 = tt.splat %21 : i1 -> tensor<128x32xi1, #blocked1> - %29 = tt.load %26, %28 : tensor<128x32x!tt.ptr, #blocked1> - %30 = tt.splat %21 : i1 -> tensor<32x128xi1, #blocked> - %31 = tt.load %27, %30, %cst_3 : tensor<32x128x!tt.ptr, #blocked> - %32 = arith.addi %arg9, %c1_i32 : i32 - %33 = arith.cmpi slt, %32, %c1_i32 : i32 - %34 = arith.select %33, %32, %c0_i32 : i32 - %35 = triton_gpu.memdesc_subview %10[%34, %c0_i32, %c0_i32] : !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %29, %35 : tensor<128x32xf16, #blocked1> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - %36 = triton_gpu.memdesc_subview %11[%34, %c0_i32, %c0_i32] : !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %31, %36 : tensor<32x128xf16, #blocked> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> - scf.yield %26, %27, %25, %34, %35, %36 : tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> - } - triton_gpu.local_dealloc %10 : !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %11 : !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> - tt.return %19#2 : tensor<128x128xf32, #mma> - } - - -// This example tests that tt.load overlaps with independent ttg.local_store which -// overlaps with independent tt.dot. 
-// num_stages == 3, double buffered - -// CHECK-LABEL: tt.func @matmul_loop_mb -// CHECK: %{{.*}}:8 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}}) -// Stage 0 -// CHECK: %[[ADDPTR_28:.*]] = tt.addptr %[[ARG6]], %{{.*}} -// CHECK: %[[MULI_29:.*]] = arith.muli %{{.*}}, %{{.*}} -// CHECK: %[[SUBI_30:.*]] = arith.subi %{{.*}}, %[[MULI_29]] -// CHECK: %[[CMPI_31:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_30]] -// CHECK: %[[SPLAT_32:.*]] = tt.splat %[[CMPI_31]] -// CHECK: %[[LOAD_33:.*]] = tt.load %[[ADDPTR_28]], %[[SPLAT_32]] -// CHECK: %[[ADDPTR_34:.*]] = tt.addptr %[[ARG7]], %{{.*}} -// CHECK: %[[SPLAT_35:.*]] = tt.splat %[[CMPI_31]] -// CHECK: %[[LOAD_36:.*]] = tt.load %[[ADDPTR_34]], %[[SPLAT_35]] -// Stage 1 -// CHECK: %[[ADDI_37:.*]] = arith.addi %[[ARG9]], %{{.*}} -// CHECK: %[[CMPI_38:.*]] = arith.cmpi slt, %[[ADDI_37]], %{{.*}} -// CHECK: %[[SELECT_39:.*]] = arith.select %[[CMPI_38]], %[[ADDI_37]], %{{.*}} -// CHECK: %[[MEMDESC_SUBVIEW_40:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}] -// CHECK: triton_gpu.local_store %[[ARG12]], %[[MEMDESC_SUBVIEW_40]] -// CHECK: %[[MEMDESC_SUBVIEW_41:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}] -// CHECK: triton_gpu.local_store %[[ARG13]], %[[MEMDESC_SUBVIEW_41]] -// Stage 2 -// CHECK: %[[LOCAL_LOAD_42:.*]] = triton_gpu.local_load %[[ARG10]] -// CHECK: %[[LOCAL_LOAD_43:.*]] = triton_gpu.local_load %[[ARG11]] -// CHECK: %[[MULF_44:.*]] = arith.mulf %[[LOCAL_LOAD_43]], %{{.*}} -// CHECK: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_42]], %[[MULF_44]], %[[ARG8]] -// CHECK: scf.yield %[[ADDPTR_28]], %[[ADDPTR_34]], %[[DOT_45]], %[[SELECT_39]], %[[MEMDESC_SUBVIEW_40]], %[[MEMDESC_SUBVIEW_41]], %[[LOAD_33]], %[[LOAD_36]] -// CHECK: } - - tt.func @matmul_loop_mb(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #mma> { - %c2 = arith.constant 2 : index - %c2_i32 = arith.constant 2 : i32 - %c1_i32 = arith.constant 1 : i32 - %c0_i32 = arith.constant 0 : i32 - %cst = arith.constant dense<4.000000e+00> : tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %cst_0 = arith.constant dense<4> : tensor<32x128xi32, #blocked> - %cst_1 = arith.constant dense<4> : tensor<128x32xi32, #blocked1> - %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> - %cst_3 = arith.constant dense<0.000000e+00> : tensor<32x128xf16, #blocked> - %0 = tt.splat %arg3 : !tt.ptr -> tensor<128x32x!tt.ptr, #blocked1> - %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> - %3 = tt.broadcast %2 : tensor<1x32xi32, #blocked1> -> tensor<128x32xi32, #blocked1> - %4 = tt.addptr %0, %3 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> - %5 = tt.splat %arg4 : !tt.ptr -> tensor<32x128x!tt.ptr, #blocked> - %6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> - %8 = 
tt.broadcast %7 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> - %9 = tt.addptr %5, %8 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> - %10 = triton_gpu.local_alloc : () -> !tt.memdesc<2x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - %11 = triton_gpu.local_alloc : () -> !tt.memdesc<2x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> - %12 = arith.cmpi slt, %arg0, %arg1 : index - %13 = tt.splat %12 : i1 -> tensor<128x32xi1, #blocked1> - %14 = tt.load %4, %13 : tensor<128x32x!tt.ptr, #blocked1> - %15 = tt.splat %12 : i1 -> tensor<32x128xi1, #blocked> - %16 = tt.load %9, %15, %cst_3 : tensor<32x128x!tt.ptr, #blocked> - %17 = arith.addi %arg0, %arg2 : index - %18 = arith.cmpi slt, %17, %arg1 : index - %19 = tt.addptr %4, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> - %20 = tt.addptr %9, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> - %21 = tt.splat %18 : i1 -> tensor<128x32xi1, #blocked1> - %22 = tt.load %19, %21 : tensor<128x32x!tt.ptr, #blocked1> - %23 = tt.splat %18 : i1 -> tensor<32x128xi1, #blocked> - %24 = tt.load %20, %23, %cst_3 : tensor<32x128x!tt.ptr, #blocked> - %25 = triton_gpu.memdesc_subview %10[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<2x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %14, %25 : tensor<128x32xf16, #blocked1> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - %26 = triton_gpu.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<2x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %16, %26 : tensor<32x128xf16, #blocked> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> - %27:8 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %19, %arg7 = %20, %arg8 = %cst_2, %arg9 = %c0_i32, %arg10 = %25, %arg11 = %26, %arg12 = %22, %arg13 = %24) -> (tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>, tensor<128x32xf16, #blocked1>, tensor<32x128xf16, #blocked>) { - %28 = arith.muli %arg2, %c2 : index - %29 = arith.subi %arg1, %28 : index - %30 = arith.cmpi slt, %arg5, %29 : index - %31 = triton_gpu.local_load %arg10 : !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %32 = triton_gpu.local_load %arg11 : !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %33 = arith.mulf %32, %cst : tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %34 = tt.dot %31, %33, %arg8 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> - %35 = tt.addptr %arg6, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> - %36 = tt.addptr %arg7, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> - %37 = tt.splat %30 : i1 -> tensor<128x32xi1, #blocked1> - %38 = tt.load %35, %37 : tensor<128x32x!tt.ptr, #blocked1> - %39 = tt.splat %30 : i1 -> 
tensor<32x128xi1, #blocked> - %40 = tt.load %36, %39, %cst_3 : tensor<32x128x!tt.ptr, #blocked> - %41 = arith.addi %arg9, %c1_i32 : i32 - %42 = arith.cmpi slt, %41, %c2_i32 : i32 - %43 = arith.select %42, %41, %c0_i32 : i32 - %44 = triton_gpu.memdesc_subview %10[%43, %c0_i32, %c0_i32] : !tt.memdesc<2x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %arg12, %44 : tensor<128x32xf16, #blocked1> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - %45 = triton_gpu.memdesc_subview %11[%43, %c0_i32, %c0_i32] : !tt.memdesc<2x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %arg13, %45 : tensor<32x128xf16, #blocked> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> - scf.yield %35, %36, %34, %43, %44, %45, %38, %40 : tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>, tensor<128x32xf16, #blocked1>, tensor<32x128xf16, #blocked> - } - triton_gpu.local_dealloc %10 : !tt.memdesc<2x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %11 : !tt.memdesc<2x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> - tt.return %27#2 : tensor<128x128xf32, #mma> - } - -// This example shows dependent loads and verifies all are moved early. -// CHECK-LABEL: tt.func @indirect_bmm_vector -// CHECK: %{{.*}}:7 = scf.for %[[ARG6:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}}) -// Stage 0 -// CHECK: %[[ADDPTR_20:.*]] = tt.addptr %[[ARG8]], %{{.*}} -// CHECK: %[[SUBI_21:.*]] = arith.subi %{{.*}}, %{{.*}} -// CHECK: %[[CMPI_22:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_21]] -// CHECK: %[[SPLAT_23:.*]] = tt.splat %[[CMPI_22]] -// CHECK: %[[LOAD_24:.*]] = tt.load %[[ADDPTR_20]], %[[SPLAT_23]] -// Stage 1.a -// CHECK: %[[EXPAND_DIMS_25:.*]] = tt.expand_dims %[[ARG13]] {axis = 1 : i32} -// CHECK: %[[BROADCAST_26:.*]] = tt.broadcast %[[EXPAND_DIMS_25]] -// CHECK: %[[MULI_27:.*]] = arith.muli %{{.*}}, %[[BROADCAST_26]] -// CHECK: %[[ADDPTR_28:.*]] = tt.addptr %{{.*}}, %[[MULI_27]] -// CHECK: %[[SPLAT_29:.*]] = tt.splat %[[CMPI_22]] -// CHECK: %[[LOAD_30:.*]] = tt.load %[[ADDPTR_28]], %[[SPLAT_29]] -// CHECK: %[[ADDPTR_31:.*]] = tt.addptr %[[ARG9]], %{{.*}} -// CHECK: %[[SUBI_32:.*]] = arith.subi %{{.*}}, %{{.*}} -// CHECK: %[[CMPI_33:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_32]] -// CHECK: %[[SPLAT_34:.*]] = tt.splat %[[CMPI_33]] -// CHECK: %[[LOAD_35:.*]] = tt.load %[[ADDPTR_31]], %[[SPLAT_34]] -// Stage 2 -// CHECK: %[[LOCAL_LOAD_36:.*]] = triton_gpu.local_load %[[ARG11]] -// CHECK: %[[LOCAL_LOAD_37:.*]] = triton_gpu.local_load %[[ARG12]] -// CHECK: %[[DOT_38:.*]] = tt.dot %[[LOCAL_LOAD_36]], %[[LOCAL_LOAD_37]], %[[ARG7]] -// Stage 1.b -// CHECK: %[[ADDI_39:.*]] = arith.addi %[[ARG10]], %{{.*}} -// CHECK: %[[CMPI_40:.*]] = arith.cmpi slt, %[[ADDI_39]], %{{.*}} -// CHECK: %[[SELECT_41:.*]] = arith.select %[[CMPI_40]], %[[ADDI_39]], %{{.*}} -// CHECK: %[[MEMDESC_SUBVIEW_42:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_41]], %{{.*}}, %{{.*}}] -// CHECK: triton_gpu.local_store 
%[[LOAD_24]], %[[MEMDESC_SUBVIEW_42]] -// CHECK: %[[MEMDESC_SUBVIEW_43:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_41]], %{{.*}}, %{{.*}}] -// CHECK: triton_gpu.local_store %[[LOAD_30]], %[[MEMDESC_SUBVIEW_43]] -// CHECK: scf.yield %[[DOT_38]], %[[ADDPTR_20]], %[[ADDPTR_31]], %[[SELECT_41]], %[[MEMDESC_SUBVIEW_42]], %[[MEMDESC_SUBVIEW_43]], %[[LOAD_35]] -// CHECK: } - - tt.func @indirect_bmm_vector(%arg0: tensor<16x16xi64, #blocked> {tt.constancy = 16 : i32, tt.divisibility = 16 : i32}, %arg1: index, %arg2: tensor<16x16x!tt.ptr, #blocked1> {tt.contiguity = 2 : i32, tt.divisibility = 16 : i32}, %arg3: tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, %arg4: tensor<16x16xi32, #blocked1> {tt.constancy = 16 : i32, tt.divisibility = 16 : i32}, %arg5: tensor<16x16x!tt.ptr, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}) -> tensor<16x16xf32, #mma> { - %c2 = arith.constant 2 : index - %c0_i32 = arith.constant 0 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma> - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %c1_i32 = arith.constant 1 : i32 - %cst_0 = arith.constant dense<1> : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> - %1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> - %2 = arith.cmpi sgt, %arg1, %c0 : index - %3 = tt.splat %2 : i1 -> tensor<16xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %4 = tt.load %arg3, %3 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %5 = arith.cmpi sgt, %arg1, %c1 : index - %6 = tt.addptr %arg3, %cst_0 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %7 = tt.splat %2 : i1 -> tensor<16x16xi1, #blocked1> - %8 = tt.load %arg2, %7 : tensor<16x16x!tt.ptr, #blocked1> - %9 = tt.expand_dims %4 {axis = 1 : i32} : tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi64, #blocked> - %10 = tt.broadcast %9 : tensor<16x1xi64, #blocked> -> tensor<16x16xi64, #blocked> - %11 = arith.muli %arg0, %10 : tensor<16x16xi64, #blocked> - %12 = tt.addptr %arg5, %11 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi64, #blocked> - %13 = tt.splat %2 : i1 -> tensor<16x16xi1, #blocked> - %14 = tt.load %12, %13 : tensor<16x16x!tt.ptr, #blocked> - %15 = tt.splat %5 : i1 -> tensor<16xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %16 = tt.load %6, %15 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %17 = triton_gpu.memdesc_subview %0[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %8, %17 : tensor<16x16xf16, #blocked1> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> - %18 = triton_gpu.memdesc_subview %1[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %14, %18 : tensor<16x16xf16, #blocked> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> - %19:7 = scf.for %arg6 = %c0 to %arg1 step %c1 iter_args(%arg7 = %cst, %arg8 = %arg2, %arg9 = %6, %arg10 = %c0_i32, %arg11 = %17, %arg12 = %18, %arg13 = %16) -> (tensor<16x16xf32, 
#mma>, tensor<16x16x!tt.ptr, #blocked1>, tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, i32, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) { - %20 = arith.subi %arg1, %c2 : index - %21 = arith.cmpi slt, %arg6, %20 : index - %22 = arith.subi %arg1, %c1 : index - %23 = arith.cmpi slt, %arg6, %22 : index - %24 = triton_gpu.local_load %arg11 : !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %25 = triton_gpu.local_load %arg12 : !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %26 = tt.dot %24, %25, %arg7 : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> - %27 = tt.addptr %arg8, %arg4 : tensor<16x16x!tt.ptr, #blocked1>, tensor<16x16xi32, #blocked1> - %28 = tt.addptr %arg9, %cst_0 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %29 = tt.splat %23 : i1 -> tensor<16x16xi1, #blocked1> - %30 = tt.load %27, %29 : tensor<16x16x!tt.ptr, #blocked1> - %31 = tt.expand_dims %arg13 {axis = 1 : i32} : tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi64, #blocked> - %32 = tt.broadcast %31 : tensor<16x1xi64, #blocked> -> tensor<16x16xi64, #blocked> - %33 = arith.muli %arg0, %32 : tensor<16x16xi64, #blocked> - %34 = tt.addptr %arg5, %33 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi64, #blocked> - %35 = tt.splat %23 : i1 -> tensor<16x16xi1, #blocked> - %36 = tt.load %34, %35 : tensor<16x16x!tt.ptr, #blocked> - %37 = tt.splat %21 : i1 -> tensor<16xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %38 = tt.load %28, %37 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %39 = arith.addi %arg10, %c1_i32 : i32 - %40 = arith.cmpi slt, %39, %c1_i32 : i32 - %41 = arith.select %40, %39, %c0_i32 : i32 - %42 = triton_gpu.memdesc_subview %0[%41, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %30, %42 : tensor<16x16xf16, #blocked1> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> - %43 = triton_gpu.memdesc_subview %1[%41, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> - triton_gpu.local_store %36, %43 : tensor<16x16xf16, #blocked> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> - scf.yield %26, %27, %28, %41, %42, %43, %38 : tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr, #blocked1>, tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, i32, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - } - triton_gpu.local_dealloc %0 : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %1 : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> - 
tt.return %19#0 : tensor<16x16xf32, #mma> - } -} - // ----- // CHECK-LABEL: sink_convert_dealloc diff --git a/test/TritonGPU/amd/amd-sched-2nd-load.mlir b/test/TritonGPU/amd/amd-sched-2nd-load.mlir index 5c173ffb4858..6dad361adc16 100644 --- a/test/TritonGPU/amd/amd-sched-2nd-load.mlir +++ b/test/TritonGPU/amd/amd-sched-2nd-load.mlir @@ -35,11 +35,11 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war %c1 = arith.constant 1 : i32 %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { + %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr, #blocked> %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x128xf16, #dotOp0> + %5 = tt.load %B_ptr : tensor<128x256x!tt.ptr, #blocked1> %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x256xf16, #dotOp1> %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> - %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr, #blocked> - %5 = tt.load %B_ptr : tensor<128x256x!tt.ptr, #blocked1> triton_gpu.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> triton_gpu.local_store %5, %B_LDS : tensor<128x256xf16, #blocked1> -> !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable> scf.yield %3 : tensor<256x256xf32, #mma> @@ -64,11 +64,11 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war %c1 = arith.constant 1 : i32 %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { + %4 = tt.load %A_ptr : tensor<256x64x!tt.ptr, #blocked> %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #dotOp0> + %5 = tt.load %B_ptr : tensor<64x256x!tt.ptr, #blocked1> %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x256xf16, #dotOp1> %3 = tt.dot %1, %2, %arg1 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> - %4 = tt.load %A_ptr : tensor<256x64x!tt.ptr, #blocked> - %5 = tt.load %B_ptr : tensor<64x256x!tt.ptr, #blocked1> triton_gpu.local_store %4, %A_LDS : tensor<256x64xf16, #blocked> -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> triton_gpu.local_store %5, %B_LDS : tensor<64x256xf16, #blocked1> -> !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> scf.yield %3 : tensor<256x256xf32, #mma> @@ -81,8 +81,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war // Should NOT apply: tile size 256x64x128 with single dot // CHECK-LABEL: sink_2nd_load_256x64x128 // CHECK: %[[tileA:.*]] = tt.load -// CHECK-NEXT: %[[tileB:.*]] = tt.load // CHECK-NEXT: local_load +// CHECK-NEXT: %[[tileB:.*]] = tt.load // CHECK-NEXT: local_load // CHECK-NEXT: tt.dot // CHECK-NEXT: triton_gpu.local_store %[[tileA]] @@ -93,11 +93,11 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war %c1 = arith.constant 1 : i32 %cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #mma> %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> 
(tensor<256x64xf32, #mma>) : i32 { + %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr, #blocked> %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x128xf16, #dotOp0> + %5 = tt.load %B_ptr : tensor<128x64x!tt.ptr, #blocked1> %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x64xf16, #dotOp1> %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x64xf16, #dotOp1> -> tensor<256x64xf32, #mma> - %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr, #blocked> - %5 = tt.load %B_ptr : tensor<128x64x!tt.ptr, #blocked1> triton_gpu.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> triton_gpu.local_store %5, %B_LDS : tensor<128x64xf16, #blocked1> -> !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable> scf.yield %3 : tensor<256x64xf32, #mma> @@ -110,8 +110,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war // Should NOT apply: tile size 256x256x32 with single dot // CHECK-LABEL: sink_2nd_load_256x256x32 // CHECK: %[[tileA:.*]] = tt.load -// CHECK-NEXT: %[[tileB:.*]] = tt.load // CHECK-NEXT: local_load +// CHECK-NEXT: %[[tileB:.*]] = tt.load // CHECK-NEXT: local_load // CHECK-NEXT: tt.dot // CHECK-NEXT: triton_gpu.local_store %[[tileA]] @@ -122,11 +122,11 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war %c1 = arith.constant 1 : i32 %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { + %4 = tt.load %A_ptr : tensor<256x32x!tt.ptr, #blocked> %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x32xf16, #dotOp0> + %5 = tt.load %B_ptr : tensor<32x256x!tt.ptr, #blocked1> %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<32x256xf16, #dotOp1> %3 = tt.dot %1, %2, %arg1 : tensor<256x32xf16, #dotOp0> * tensor<32x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> - %4 = tt.load %A_ptr : tensor<256x32x!tt.ptr, #blocked> - %5 = tt.load %B_ptr : tensor<32x256x!tt.ptr, #blocked1> triton_gpu.local_store %4, %A_LDS : tensor<256x32xf16, #blocked> -> !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable> triton_gpu.local_store %5, %B_LDS : tensor<32x256xf16, #blocked1> -> !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable> scf.yield %3 : tensor<256x256xf32, #mma> @@ -142,8 +142,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war // Should NOT apply: the 2nd load has a user before the dot // CHECK-LABEL: sink_2nd_load_128x128x128_user_before_dot // CHECK: %[[tileA:.*]] = tt.load -// CHECK-NEXT: %[[tileB:.*]] = tt.load // CHECK-NEXT: local_load +// CHECK-NEXT: %[[tileB:.*]] = tt.load // CHECK-NEXT: local_load // CHECK-NEXT: tt.store // CHECK-NEXT: tt.dot @@ -154,10 +154,10 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war %c1 = arith.constant 1 : i32 %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<128x128xf32, #mma>) : i32 { - %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp0> - %2 = 
triton_gpu.local_load %B_LDS : !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp1> %4 = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked> + %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp0> %5 = tt.load %B_ptr : tensor<128x128x!tt.ptr, #blocked> + %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp1> tt.store %B_ptr, %5 : tensor<128x128x!tt.ptr, #blocked> %3 = tt.dot %1, %2, %arg1 : tensor<128x128xf16, #dotOp0> * tensor<128x128xf16, #dotOp1> -> tensor<128x128xf32, #mma> triton_gpu.local_store %4, %A_LDS : tensor<128x128xf16, #blocked> -> !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable> @@ -174,12 +174,12 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war // Category 3: two dots in the for loop. Make sure the optimization is not applied // should NOT apply: two dots // CHECK-LABEL: sink_2nd_load_256x256x64_two_dot -// CHECK: triton_gpu.local_load +// CHECK: tt.load +// CHECK-NEXT: tt.load +// CHECK-NEXT: triton_gpu.local_load // CHECK-NEXT: triton_gpu.local_load // CHECK-NEXT: tt.dot // CHECK-NEXT: tt.dot -// CHECK-NEXT: tt.load -// CHECK-NEXT: tt.load // CHECK-NEXT: triton_gpu.local_store // CHECK-NEXT: triton_gpu.local_store #blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> diff --git a/test/TritonGPU/loop-pipeline-hip.mlir b/test/TritonGPU/loop-pipeline-hip.mlir index 3abcc581b906..d6653b2b004c 100644 --- a/test/TritonGPU/loop-pipeline-hip.mlir +++ b/test/TritonGPU/loop-pipeline-hip.mlir @@ -35,9 +35,9 @@ module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> // CHECK: triton_gpu.local_store // CHECK: scf.for + // CHECK: tt.load // CHECK: tt.dot // CHECK: tt.dot - // CHECK: tt.load // CHECK: triton_gpu.local_store // CHECK: scf.yield %17:2 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 = %cst_1, %arg4 = %cst_2) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) : i32 { @@ -165,9 +165,9 @@ module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 // CHECK-LABEL: tt.func public @add_barrier_kernel // CHECK: tt.load // CHECK: scf.for +// CHECK: tt.load // CHECK: gpu.barrier // CHECK: tt.store -// CHECK: tt.load // CHECK: scf.yield // CHECK: gpu.barrier // CHECK: tt.store diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index c0c4c1860e34..82f726899ea4 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -1,5 +1,6 @@ // RUN: triton-opt %s -split-input-file -tritongpu-loop-scheduling=num-stages=3 -tritongpu-pipeline=num-stages=3 -canonicalize | FileCheck %s --check-prefixes=COMMON,CHECK // RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline-v2=num_stages=2 -canonicalize | FileCheck %s --check-prefixes=COMMON,AMD +// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline-v2="num_stages=2 prefetch=1" -canonicalize | FileCheck %s --check-prefixes=COMMON,AMD_PREFETCH // 4 warps // matmul: 128x32 @ 32x128 -> 128x128 @@ -62,22 +63,22 @@ // AMD-DAG: %[[C0:.*]] = arith.constant 0 : index // AMD: %[[UB1:.*]] = arith.subi %[[UB:.*]], %arg2 : index // AMD: %[[FOR:.*]]:6 = scf.for %[[ARG5:.*]] = %[[LB:.*]] 
to %[[UB1]] step %[[STEP:.*]] iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}) -// AMD: %[[LOCAL_LOAD_32:.*]] = triton_gpu.local_load %[[ARG10]] -// AMD: %[[LOCAL_LOAD_33:.*]] = triton_gpu.local_load %[[ARG11]] -// AMD: %[[MULF_34:.*]] = arith.mulf %[[LOCAL_LOAD_33]], %{{.*}} -// AMD: %[[DOT_35:.*]] = tt.dot %[[LOCAL_LOAD_32]], %[[MULF_34]], %[[ARG8]] -// AMD: %[[ADDPTR_36:.*]] = tt.addptr %[[ARG6]], %{{.*}} -// AMD: %[[ADDPTR_37:.*]] = tt.addptr %[[ARG7]], %{{.*}} -// AMD: %[[LOAD_38:.*]] = tt.load %[[ADDPTR_36]] -// AMD: %[[LOAD_39:.*]] = tt.load %[[ADDPTR_37]] -// AMD: %[[ADDI_40:.*]] = arith.addi %[[ARG9]], %{{.*}} -// AMD: %[[CMPI_41:.*]] = arith.cmpi slt, %[[ADDI_40]], %{{.*}} -// AMD: %[[SELECT_42:.*]] = arith.select %[[CMPI_41]], %[[ADDI_40]], %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_43:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_42]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_38]], %[[MEMDESC_SUBVIEW_43]] -// AMD: %[[MEMDESC_SUBVIEW_44:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_42]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_39]], %[[MEMDESC_SUBVIEW_44]] -// AMD: scf.yield %[[ADDPTR_36]], %[[ADDPTR_37]], %[[DOT_35]], %[[SELECT_42]], %[[MEMDESC_SUBVIEW_43]], %[[MEMDESC_SUBVIEW_44]] +// AMD: %[[ADDPTR_34:.*]] = tt.addptr %[[ARG6]], %{{.*}} +// AMD: %[[ADDPTR_35:.*]] = tt.addptr %[[ARG7]], %{{.*}} +// AMD: %[[LOAD_36:.*]] = tt.load %[[ADDPTR_34]] +// AMD: %[[LOCAL_LOAD_37:.*]] = triton_gpu.local_load %[[ARG10]] +// AMD: %[[LOAD_38:.*]] = tt.load %[[ADDPTR_35]] +// AMD: %[[LOCAL_LOAD_39:.*]] = triton_gpu.local_load %[[ARG11]] +// AMD: %[[MULF_40:.*]] = arith.mulf %[[LOCAL_LOAD_39]], %{{.*}} +// AMD: %[[DOT_41:.*]] = tt.dot %[[LOCAL_LOAD_37]], %[[MULF_40]], %[[ARG8]] +// AMD: %[[ADDI_42:.*]] = arith.addi %[[ARG9]], %{{.*}} +// AMD: %[[CMPI_43:.*]] = arith.cmpi slt, %[[ADDI_42]], %{{.*}} +// AMD: %[[SELECT_44:.*]] = arith.select %[[CMPI_43]], %[[ADDI_42]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_45:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_44]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_36]], %[[MEMDESC_SUBVIEW_45]] +// AMD: %[[MEMDESC_SUBVIEW_46:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_44]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_38]], %[[MEMDESC_SUBVIEW_46]] +// AMD: scf.yield %[[ADDPTR_34]], %[[ADDPTR_35]], %[[DOT_41]], %[[SELECT_44]], %[[MEMDESC_SUBVIEW_45]], %[[MEMDESC_SUBVIEW_46]] // AMD: } // AMD: %[[CMPI_21:.*]] = arith.cmpi slt, %[[STEP]], %[[C0]] // AMD: %[[SELECT_22:.*]] = arith.select %[[CMPI_21]], %[[C1]], %[[CM1]] @@ -99,6 +100,34 @@ // AMD: triton_gpu.local_dealloc %{{.*}} // AMD: triton_gpu.local_dealloc %{{.*}} +// Prefetch pipelining adds another stage in between global load and compute. +// This stage will local_store, then local_load, creating a prefetch from shared +// memory into a register buffer for compute. 
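//
// A rough steady-state sketch of the pipelined loop body with prefetch enabled
// (num_stages=2 prefetch=1). The stage labels and the %a_reg/%b_reg names here
// are illustrative only; the AMD_PREFETCH checks that follow show the exact op
// order expected from the pass:
//
// scf.for ... {
//   // prefetch stage: store the tiles fetched by the previous iteration
//   tt.local_store %a_next
//   tt.local_store %b_next
//   // load stage: issue global loads for a later iteration
//   %a_next = tt.load %aptr
//   %b_next = tt.load %bptr
//   // compute stage: consume the register buffers prefetched last iteration
//   tt.dot %c, %a_reg, %b_reg
//   // prefetch stage: read the freshly stored tiles from shared memory into
//   // registers for the next iteration's dot
//   %a_reg = tt.local_load %a_tile
//   %b_reg = tt.local_load %b_tile
//   yield
// }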
+// +// AMD_PREFETCH-LABEL: tt.func @matmul_loop +// AMD_PREFETCH: triton_gpu.local_alloc +// AMD_PREFETCH: triton_gpu.local_alloc +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: triton_gpu.local_store +// AMD_PREFETCH: triton_gpu.local_store +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: triton_gpu.local_load +// AMD_PREFETCH: triton_gpu.local_load +// AMD_PREFETCH: scf.for +// AMD_PREFETCH: triton_gpu.local_store +// AMD_PREFETCH: triton_gpu.local_store +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: tt.dot +// AMD_PREFETCH: triton_gpu.local_load +// AMD_PREFETCH: triton_gpu.local_load +// AMD_PREFETCH: scf.yield +// AMD_PREFETCH: tt.dot +// AMD_PREFETCH: tt.dot +// AMD_PREFETCH: tt.return + module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, @@ -200,10 +229,12 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, // AMD: %[[SUBVIEW1:.*]] = triton_gpu.memdesc_subview // AMD: triton_gpu.local_store %{{.+}}, %[[SUBVIEW1]] // AMD: %[[FOR:.*]]:6 = scf.for -// AMD-COUNT-2: triton_gpu.local_load -// AMD: tt.dot // AMD-COUNT-2: tt.addptr -// AMD-COUNT-2: tt.load +// AMD: tt.load +// AMD: triton_gpu.local_load +// AMD: tt.load +// AMD: triton_gpu.local_load +// AMD: tt.dot // AMD: %[[SUBVIEW0:.*]] = triton_gpu.memdesc_subview // AMD: triton_gpu.local_store %{{.+}}, %[[SUBVIEW0]] // AMD: %[[SUBVIEW1:.*]] = triton_gpu.memdesc_subview @@ -217,6 +248,8 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, // AMD-COUNT-2: triton_gpu.local_dealloc // AMD: scf.yield %[[SEL1]] +// AMD_PREFETCH-LABEL: tt.func @matmul_loop_nested + tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C>{ @@ -299,10 +332,10 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, // AMD: triton_gpu.local_store %[[LOAD_15]], %[[MEMDESC_SUBVIEW_16]] // AMD: %[[SUBI_17:.*]] = arith.subi %{{.*}}, %{{.*}} // AMD: %{{.*}}:4 = scf.for %[[ARG5:.*]] = %{{.*}} to %[[SUBI_17]] step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %[[MEMDESC_SUBVIEW_16]]) -// AMD: %[[LOCAL_LOAD_30:.*]] = triton_gpu.local_load %[[ARG9]] -// AMD: %[[DOT_31:.*]] = tt.dot %[[CONVERT_LAYOUT_11]], %[[LOCAL_LOAD_30]], %[[ARG7]] // AMD: %[[ADDPTR_32:.*]] = tt.addptr %[[ARG6]], %{{.*}} // AMD: %[[LOAD_33:.*]] = tt.load %[[ADDPTR_32]] +// AMD: %[[LOCAL_LOAD_30:.*]] = triton_gpu.local_load %[[ARG9]] +// AMD: %[[DOT_31:.*]] = tt.dot %[[CONVERT_LAYOUT_11]], %[[LOCAL_LOAD_30]], %[[ARG7]] // AMD: %[[ADDI_34:.*]] = arith.addi %[[ARG8]], %{{.*}} // AMD: %[[CMPI_35:.*]] = arith.cmpi slt, %[[ADDI_34]], %{{.*}} // AMD: %[[SELECT_36:.*]] = arith.select %[[CMPI_35]], %[[ADDI_34]], %{{.*}} @@ -311,6 +344,22 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, // AMD: scf.yield %[[ADDPTR_32]], %[[DOT_31]], %[[SELECT_36]], %[[MEMDESC_SUBVIEW_37]] // AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_12]] +// AMD_PREFETCH-LABEL: tt.func @matmul_loop_single_pipeline +// AMD_PREFETCH: triton_gpu.local_alloc +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: triton_gpu.local_store +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: triton_gpu.local_load +// AMD_PREFETCH: scf.for +// AMD_PREFETCH: triton_gpu.local_store +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: 
tt.dot +// AMD_PREFETCH: triton_gpu.local_load +// AMD_PREFETCH: scf.yield +// AMD_PREFETCH: tt.dot +// AMD_PREFETCH: tt.dot +// AMD_PREFETCH: tt.return + tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> { @@ -365,83 +414,110 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, // CHECK: triton_gpu.async_wait {{.*}} {num = 2 : i32} // AMD-LABEL: tt.func @indirect_bmm_scalar -// AMD: %[[LOCAL_ALLOC_0:.*]] = triton_gpu.local_alloc -// AMD: %[[LOCAL_ALLOC_1:.*]] = triton_gpu.local_alloc -// AMD: %[[CMPI_2:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} -// AMD: %[[SPLAT_3:.*]] = tt.splat %[[CMPI_2]] -// AMD: %[[LOAD_4:.*]] = tt.load %{{.*}}, %[[SPLAT_3]] -// AMD: %[[LOAD_5:.*]] = tt.load %{{.*}}, %[[CMPI_2]] -// AMD: %[[MULI_6:.*]] = arith.muli %{{.*}}, %[[LOAD_5]] -// AMD: %[[SPLAT_7:.*]] = tt.splat %[[MULI_6]] -// AMD: %[[ADDPTR_8:.*]] = tt.addptr %{{.*}}, %[[SPLAT_7]] -// AMD: %[[SPLAT_9:.*]] = tt.splat %[[CMPI_2]] -// AMD: %[[LOAD_10:.*]] = tt.load %[[ADDPTR_8]], %[[SPLAT_9]] -// AMD: %[[CMPI_11:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} -// AMD: %[[ADDPTR_12:.*]] = tt.addptr %{{.*}}, %{{.*}} -// AMD: %[[ADDPTR_13:.*]] = tt.addptr %{{.*}}, %{{.*}} -// AMD: %[[SPLAT_14:.*]] = tt.splat %[[CMPI_11]] -// AMD: %[[LOAD_15:.*]] = tt.load %[[ADDPTR_12]], %[[SPLAT_14]] -// AMD: %[[LOAD_16:.*]] = tt.load %[[ADDPTR_13]], %[[CMPI_11]] -// AMD: %[[MULI_17:.*]] = arith.muli %{{.*}}, %[[LOAD_16]] -// AMD: %[[SPLAT_18:.*]] = tt.splat %[[MULI_17]] -// AMD: %[[ADDPTR_19:.*]] = tt.addptr %{{.*}}, %[[SPLAT_18]] -// AMD: %[[SPLAT_20:.*]] = tt.splat %[[CMPI_11]] -// AMD: %[[LOAD_21:.*]] = tt.load %[[ADDPTR_19]], %[[SPLAT_20]] -// AMD: %[[MEMDESC_SUBVIEW_22:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%{{.*}}, %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_4]], %[[MEMDESC_SUBVIEW_22]] -// AMD: %[[MEMDESC_SUBVIEW_23:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%{{.*}}, %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_10]], %[[MEMDESC_SUBVIEW_23]] -// AMD: %[[SUBI_24:.*]] = arith.subi %{{.*}}, %{{.*}} -// AMD: %{{.*}}:8 = scf.for %[[ARG6:.*]] = %{{.*}} to %[[SUBI_24]] step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %[[ADDPTR_12]], %[[ARG9:.*]] = %[[ADDPTR_13]], %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %[[MEMDESC_SUBVIEW_22]], %[[ARG12:.*]] = %[[MEMDESC_SUBVIEW_23]], %[[ARG13:.*]] = %[[LOAD_15]], %[[ARG14:.*]] = %[[LOAD_21]]) -// AMD: %[[LOCAL_LOAD_43:.*]] = triton_gpu.local_load %[[ARG11]] -// AMD: %[[LOCAL_LOAD_44:.*]] = triton_gpu.local_load %[[ARG12]] -// AMD: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_43]], %[[LOCAL_LOAD_44]], %[[ARG7]] +// AMD: %[[LOCAL_ALLOC_0:.*]] = triton_gpu.local_alloc +// AMD: %[[LOCAL_ALLOC_1:.*]] = triton_gpu.local_alloc +// AMD: %[[CMPI_2:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} +// AMD: %[[SPLAT_3:.*]] = tt.splat %[[CMPI_2]] +// AMD: %[[LOAD_4:.*]] = tt.load %{{.*}}, %[[SPLAT_3]] +// AMD: %[[LOAD_5:.*]] = tt.load %{{.*}}, %[[CMPI_2]] +// AMD: %[[MULI_6:.*]] = arith.muli %{{.*}}, %[[LOAD_5]] +// AMD: %[[SPLAT_7:.*]] = tt.splat %[[MULI_6]] +// AMD: %[[ADDPTR_8:.*]] = tt.addptr %{{.*}}, %[[SPLAT_7]] +// AMD: %[[SPLAT_9:.*]] = tt.splat %[[CMPI_2]] +// AMD: %[[LOAD_10:.*]] = tt.load %[[ADDPTR_8]], %[[SPLAT_9]] +// AMD: %[[CMPI_11:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_12:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: 
triton_gpu.local_store %[[LOAD_4]], %[[MEMDESC_SUBVIEW_12]] +// AMD: %[[MEMDESC_SUBVIEW_13:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_10]], %[[MEMDESC_SUBVIEW_13]] +// AMD: %[[ADDPTR_14:.*]] = tt.addptr %{{.*}}, %{{.*}} +// AMD: %[[ADDPTR_15:.*]] = tt.addptr %{{.*}}, %{{.*}} +// AMD: %[[SPLAT_16:.*]] = tt.splat %[[CMPI_11]] +// AMD: %[[LOAD_17:.*]] = tt.load %[[ADDPTR_14]], %[[SPLAT_16]] +// AMD: %[[LOAD_18:.*]] = tt.load %[[ADDPTR_15]], %[[CMPI_11]] +// AMD: %[[MULI_19:.*]] = arith.muli %{{.*}}, %[[LOAD_18]] +// AMD: %[[SPLAT_20:.*]] = tt.splat %[[MULI_19]] +// AMD: %[[ADDPTR_21:.*]] = tt.addptr %{{.*}}, %[[SPLAT_20]] +// AMD: %[[SPLAT_22:.*]] = tt.splat %[[CMPI_11]] +// AMD: %[[LOAD_23:.*]] = tt.load %[[ADDPTR_21]], %[[SPLAT_22]] +// AMD: %[[SUBI_24:.*]] = arith.subi %{{.*}}, %{{.*}} +// AMD: %{{.*}}:8 = scf.for %[[ARG6:.*]] = %{{.*}} to %[[SUBI_24]] step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %[[ADDPTR_14]], %[[ARG9:.*]] = %[[ADDPTR_15]], %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %[[LOAD_17]], %[[ARG12:.*]] = %[[LOAD_23]], %[[ARG13:.*]] = %[[MEMDESC_SUBVIEW_12]], %[[ARG14:.*]] = %[[MEMDESC_SUBVIEW_13]]) +// AMD: %[[ADDI_41:.*]] = arith.addi %[[ARG10]], %{{.*}} +// AMD: %[[CMPI_42:.*]] = arith.cmpi slt, %[[ADDI_41]], %{{.*}} +// AMD: %[[SELECT_43:.*]] = arith.select %[[CMPI_42]], %[[ADDI_41]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_44:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_43]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[ARG11]], %[[MEMDESC_SUBVIEW_44]] +// AMD: %[[MEMDESC_SUBVIEW_45:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_43]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[ARG12]], %[[MEMDESC_SUBVIEW_45]] // AMD: %[[ADDPTR_46:.*]] = tt.addptr %[[ARG8]], %{{.*}} // AMD: %[[ADDPTR_47:.*]] = tt.addptr %[[ARG9]], %{{.*}} // AMD: %[[LOAD_48:.*]] = tt.load %[[ADDPTR_46]] -// AMD: %[[LOAD_49:.*]] = tt.load %[[ADDPTR_47]] -// AMD: %[[MULI_50:.*]] = arith.muli %{{.*}}, %[[LOAD_49]] -// AMD: %[[SPLAT_51:.*]] = tt.splat %[[MULI_50]] -// AMD: %[[ADDPTR_52:.*]] = tt.addptr %{{.*}}, %[[SPLAT_51]] -// AMD: %[[LOAD_53:.*]] = tt.load %[[ADDPTR_52]] -// AMD: %[[ADDI_54:.*]] = arith.addi %[[ARG10]], %{{.*}} -// AMD: %[[CMPI_55:.*]] = arith.cmpi slt, %[[ADDI_54]], %{{.*}} -// AMD: %[[SELECT_56:.*]] = arith.select %[[CMPI_55]], %[[ADDI_54]], %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_57:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_56]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[ARG13]], %[[MEMDESC_SUBVIEW_57]] -// AMD: %[[MEMDESC_SUBVIEW_58:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_56]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[ARG14]], %[[MEMDESC_SUBVIEW_58]] -// AMD: scf.yield %[[DOT_45]], %[[ADDPTR_46]], %[[ADDPTR_47]], %[[SELECT_56]], %[[MEMDESC_SUBVIEW_57]], %[[MEMDESC_SUBVIEW_58]], %[[LOAD_48]], %[[LOAD_53]] -// AMD: } -// AMD: %[[CMPI_26:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}} -// AMD: %[[CMPI_27:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}} -// AMD: %[[LOCAL_LOAD_28:.*]] = triton_gpu.local_load %{{.*}}#4 -// AMD: %[[LOCAL_LOAD_29:.*]] = triton_gpu.local_load %{{.*}}#5 -// AMD: %[[IF_30:.*]] = scf.if %[[CMPI_26]] -// AMD: %[[DOT_41:.*]] = tt.dot %[[LOCAL_LOAD_28]], %[[LOCAL_LOAD_29]], %{{.*}}#0 -// AMD: scf.yield %[[DOT_41]] -// AMD: } else { -// AMD: scf.yield %{{.*}}#0 -// AMD: } -// AMD: %[[ADDI_31:.*]] = arith.addi %{{.*}}#3, %{{.*}} -// AMD: %[[CMPI_32:.*]] = arith.cmpi slt, %[[ADDI_31]], 
%{{.*}} -// AMD: %[[SELECT_33:.*]] = arith.select %[[CMPI_32]], %[[ADDI_31]], %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_34:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_33]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %{{.*}}#6, %[[MEMDESC_SUBVIEW_34]] -// AMD: %[[MEMDESC_SUBVIEW_35:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_33]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %{{.*}}#7, %[[MEMDESC_SUBVIEW_35]] -// AMD: %[[SELECT_36:.*]] = arith.select %[[CMPI_26]], %[[IF_30]], %{{.*}}#0 -// AMD: %[[LOCAL_LOAD_37:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_34]] -// AMD: %[[LOCAL_LOAD_38:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_35]] -// AMD: %[[IF_39:.*]] = scf.if %[[CMPI_27]] -// AMD: %[[DOT_41:.*]] = tt.dot %[[LOCAL_LOAD_37]], %[[LOCAL_LOAD_38]], %[[SELECT_36]] -// AMD: scf.yield %[[DOT_41]] -// AMD: } else { -// AMD: scf.yield %[[SELECT_36]] -// AMD: } -// AMD: %[[SELECT_40:.*]] = arith.select %[[CMPI_27]], %[[IF_39]], %[[SELECT_36]] -// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_0]] -// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_1]] +// AMD: %[[LOCAL_LOAD_49:.*]] = triton_gpu.local_load %[[ARG13]] +// AMD: %[[LOAD_50:.*]] = tt.load %[[ADDPTR_47]] +// AMD: %[[MULI_51:.*]] = arith.muli %{{.*}}, %[[LOAD_50]] +// AMD: %[[SPLAT_52:.*]] = tt.splat %[[MULI_51]] +// AMD: %[[ADDPTR_53:.*]] = tt.addptr %{{.*}}, %[[SPLAT_52]] +// AMD: %[[LOAD_54:.*]] = tt.load %[[ADDPTR_53]] +// AMD: %[[LOCAL_LOAD_55:.*]] = triton_gpu.local_load %[[ARG14]] +// AMD: %[[DOT_56:.*]] = tt.dot %[[LOCAL_LOAD_49]], %[[LOCAL_LOAD_55]], %[[ARG7]] +// AMD: scf.yield %[[DOT_56]], %[[ADDPTR_46]], %[[ADDPTR_47]], %[[SELECT_43]], %[[LOAD_48]], %[[LOAD_54]], %[[MEMDESC_SUBVIEW_44]], %[[MEMDESC_SUBVIEW_45]] +// AMD: } +// AMD: %[[CMPI_26:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}} +// AMD: %[[CMPI_27:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}} +// AMD: %[[ADDI_28:.*]] = arith.addi %{{.*}}#3, %{{.*}} +// AMD: %[[CMPI_29:.*]] = arith.cmpi slt, %[[ADDI_28]], %{{.*}} +// AMD: %[[SELECT_30:.*]] = arith.select %[[CMPI_29]], %[[ADDI_28]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_31:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_30]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %{{.*}}#4, %[[MEMDESC_SUBVIEW_31]] +// AMD: %[[MEMDESC_SUBVIEW_32:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_30]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %{{.*}}#5, %[[MEMDESC_SUBVIEW_32]] +// AMD: %[[LOCAL_LOAD_33:.*]] = triton_gpu.local_load %{{.*}}#6 +// AMD: %[[LOCAL_LOAD_34:.*]] = triton_gpu.local_load %{{.*}}#7 +// AMD: %[[IF_35:.*]] = scf.if %[[CMPI_26]] +// AMD: %[[DOT_41:.*]] = tt.dot %[[LOCAL_LOAD_33]], %[[LOCAL_LOAD_34]], %{{.*}}#0 +// AMD: scf.yield %[[DOT_41]] +// AMD: } else { +// AMD: scf.yield %{{.*}}#0 +// AMD: } +// AMD: %[[SELECT_36:.*]] = arith.select %[[CMPI_26]], %[[IF_35]], %{{.*}}#0 +// AMD: %[[LOCAL_LOAD_37:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_31]] +// AMD: %[[LOCAL_LOAD_38:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_32]] +// AMD: %[[IF_39:.*]] = scf.if %[[CMPI_27]] +// AMD: %[[DOT_41:.*]] = tt.dot %[[LOCAL_LOAD_37]], %[[LOCAL_LOAD_38]], %[[SELECT_36]] +// AMD: scf.yield %[[DOT_41]] +// AMD: } else { +// AMD: scf.yield %[[SELECT_36]] +// AMD: } +// AMD: %[[SELECT_40:.*]] = arith.select %[[CMPI_27]], %[[IF_39]], %[[SELECT_36]] +// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_0]] +// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_1]] + +// AMD_PREFETCH-LABEL: tt.func @indirect_bmm_scalar +// AMD_PREFETCH: triton_gpu.local_alloc +// 
AMD_PREFETCH: triton_gpu.local_alloc +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: triton_gpu.local_store +// AMD_PREFETCH: triton_gpu.local_store +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: triton_gpu.local_load +// AMD_PREFETCH: triton_gpu.local_load +// AMD_PREFETCH: scf.for +// AMD_PREFETCH: triton_gpu.local_store +// AMD_PREFETCH: triton_gpu.local_store +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: tt.dot +// AMD_PREFETCH: triton_gpu.local_load +// AMD_PREFETCH: scf.yield +// AMD_PREFETCH: tt.dot +// AMD_PREFETCH: tt.dot +// AMD_PREFETCH: tt.dot +// AMD_PREFETCH: tt.return tt.func @indirect_bmm_scalar(%77: i64 {tt.divisibility=16: i32}, %76: index, @@ -492,11 +568,13 @@ tt.func @indirect_bmm_scalar(%77: i64 {tt.divisibility=16: i32}, // AMD-LABEL: tt.func @indirect_bmm_scalar_dist_one // AMD-COUNT-4: tt.load // AMD: scf.for -// AMD: tt.dot // AMD: tt.load +// AMD: tt.dot // AMD: triton_gpu.local_store // AMD: scf.yield +// AMD_PREFETCH-LABEL: tt.func @indirect_bmm_scalar_dist_one + tt.func @indirect_bmm_scalar_dist_one(%77: i64 {tt.divisibility=16: i32}, %76: index, %49: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, @@ -560,40 +638,42 @@ tt.func @indirect_bmm_scalar_dist_one(%77: i64 {tt.divisibility=16: i32}, // AMD: %[[ADDPTR_6:.*]] = tt.addptr %{{.*}}, %{{.*}} // AMD: %[[SPLAT_7:.*]] = tt.splat %[[CMPI_2]] // AMD: %[[LOAD_8:.*]] = tt.load %{{.*}}, %[[SPLAT_7]] -// AMD: %[[EXPAND_DIMS_9:.*]] = tt.expand_dims %[[LOAD_4]] {axis = 1 : i32} -// AMD: %[[BROADCAST_10:.*]] = tt.broadcast %[[EXPAND_DIMS_9]] -// AMD: %[[MULI_11:.*]] = arith.muli %{{.*}}, %[[BROADCAST_10]] -// AMD: %[[ADDPTR_12:.*]] = tt.addptr %{{.*}}, %[[MULI_11]] -// AMD: %[[SPLAT_13:.*]] = tt.splat %[[CMPI_2]] -// AMD: %[[LOAD_14:.*]] = tt.load %[[ADDPTR_12]], %[[SPLAT_13]] -// AMD: %[[SPLAT_15:.*]] = tt.splat %[[CMPI_5]] -// AMD: %[[LOAD_16:.*]] = tt.load %[[ADDPTR_6]], %[[SPLAT_15]] +// AMD: %[[SPLAT_9:.*]] = tt.splat %[[CMPI_5]] +// AMD: %[[LOAD_10:.*]] = tt.load %[[ADDPTR_6]], %[[SPLAT_9]] +// AMD: %[[EXPAND_DIMS_11:.*]] = tt.expand_dims %[[LOAD_4]] {axis = 1 : i32} +// AMD: %[[BROADCAST_12:.*]] = tt.broadcast %[[EXPAND_DIMS_11]] +// AMD: %[[MULI_13:.*]] = arith.muli %{{.*}}, %[[BROADCAST_12]] +// AMD: %[[ADDPTR_14:.*]] = tt.addptr %{{.*}}, %[[MULI_13]] +// AMD: %[[SPLAT_15:.*]] = tt.splat %[[CMPI_2]] +// AMD: %[[LOAD_16:.*]] = tt.load %[[ADDPTR_14]], %[[SPLAT_15]] // AMD: %[[MEMDESC_SUBVIEW_17:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%{{.*}}, %{{.*}}, %{{.*}}] // AMD: triton_gpu.local_store %[[LOAD_8]], %[[MEMDESC_SUBVIEW_17]] // AMD: %[[MEMDESC_SUBVIEW_18:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%{{.*}}, %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_14]], %[[MEMDESC_SUBVIEW_18]] +// AMD: triton_gpu.local_store %[[LOAD_16]], %[[MEMDESC_SUBVIEW_18]] // AMD: %[[SUBI_19:.*]] = arith.subi %{{.*}}, %{{.*}} -// AMD: %{{.*}}:7 = scf.for %[[ARG6:.*]] = %{{.*}} to %[[SUBI_19]] step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %[[ADDPTR_6]], %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %[[MEMDESC_SUBVIEW_17]], %[[ARG12:.*]] = %[[MEMDESC_SUBVIEW_18]], %[[ARG13:.*]] = %[[LOAD_16]]) -// AMD: %[[LOCAL_LOAD_47:.*]] = triton_gpu.local_load %[[ARG11]] -// AMD: %[[LOCAL_LOAD_48:.*]] = triton_gpu.local_load %[[ARG12]] -// AMD: %[[DOT_49:.*]] = tt.dot %[[LOCAL_LOAD_47]], %[[LOCAL_LOAD_48]], 
%[[ARG7]] -// AMD: %[[ADDPTR_50:.*]] = tt.addptr %[[ARG8]], %{{.*}} -// AMD: %[[ADDPTR_51:.*]] = tt.addptr %[[ARG9]], %{{.*}} -// AMD: %[[LOAD_52:.*]] = tt.load %[[ADDPTR_50]] -// AMD: %[[EXPAND_DIMS_53:.*]] = tt.expand_dims %[[ARG13]] {axis = 1 : i32} -// AMD: %[[BROADCAST_54:.*]] = tt.broadcast %[[EXPAND_DIMS_53]] -// AMD: %[[MULI_55:.*]] = arith.muli %{{.*}}, %[[BROADCAST_54]] -// AMD: %[[ADDPTR_56:.*]] = tt.addptr %{{.*}}, %[[MULI_55]] -// AMD: %[[LOAD_57:.*]] = tt.load %[[ADDPTR_56]] -// AMD: %[[LOAD_58:.*]] = tt.load %[[ADDPTR_51]] -// AMD: %[[ADDI_59:.*]] = arith.addi %[[ARG10]], %{{.*}} -// AMD: %[[CMPI_60:.*]] = arith.cmpi slt, %[[ADDI_59]], %{{.*}} -// AMD: %[[SELECT_61:.*]] = arith.select %[[CMPI_60]], %[[ADDI_59]], %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_62:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_61]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_52]], %[[MEMDESC_SUBVIEW_62]] -// AMD: %[[MEMDESC_SUBVIEW_63:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_61]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_57]], %[[MEMDESC_SUBVIEW_63]] -// AMD: scf.yield %[[DOT_49]], %[[ADDPTR_50]], %[[ADDPTR_51]], %[[SELECT_61]], %[[MEMDESC_SUBVIEW_62]], %[[MEMDESC_SUBVIEW_63]], %[[LOAD_58]] +// AMD: %{{.*}}:7 = scf.for %[[ARG6:.*]] = %{{.*}} to %[[SUBI_19]] step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %[[ADDPTR_6]], %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %[[MEMDESC_SUBVIEW_17]], %[[ARG12:.*]] = %[[LOAD_10]], %[[ARG13:.*]] = %[[MEMDESC_SUBVIEW_18]]) +// AMD: %[[ADDPTR_47:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// AMD: %[[ADDPTR_48:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// AMD: %[[LOAD_49:.*]] = tt.load %[[ADDPTR_47]] +// AMD: %[[LOCAL_LOAD_50:.*]] = triton_gpu.local_load %[[ARG11]] +// AMD: %[[LOAD_51:.*]] = tt.load %[[ADDPTR_48]] +// AMD: %[[EXPAND_DIMS_52:.*]] = tt.expand_dims %[[ARG12]] {axis = 1 : i32} +// AMD: %[[BROADCAST_53:.*]] = tt.broadcast %[[EXPAND_DIMS_52]] +// AMD: %[[MULI_54:.*]] = arith.muli %{{.*}}, %[[BROADCAST_53]] +// AMD: %[[ADDPTR_55:.*]] = tt.addptr %{{.*}}, %[[MULI_54]] +// AMD: %[[LOAD_56:.*]] = tt.load %[[ADDPTR_55]] +// AMD: %[[LOCAL_LOAD_57:.*]] = triton_gpu.local_load %[[ARG13]] +// AMD: %[[DOT_58:.*]] = tt.dot %[[LOCAL_LOAD_50]], %[[LOCAL_LOAD_57]], %[[ARG7]] +// AMD: %[[ADDI_59:.*]] = arith.addi %[[ARG10]], %{{.*}} +// AMD: %[[CMPI_60:.*]] = arith.cmpi slt, %[[ADDI_59]], %{{.*}} +// AMD: %[[SELECT_61:.*]] = arith.select %[[CMPI_60]], %[[ADDI_59]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_62:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_61]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_49]], %[[MEMDESC_SUBVIEW_62]] +// AMD: %[[MEMDESC_SUBVIEW_63:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_61]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_56]], %[[MEMDESC_SUBVIEW_63]] +// AMD: scf.yield %[[DOT_58]], %[[ADDPTR_47]], %[[ADDPTR_48]], %[[SELECT_61]], %[[MEMDESC_SUBVIEW_62]], %[[LOAD_51]], %[[MEMDESC_SUBVIEW_63]] + +// AMD_PREFETCH-LABEL: tt.func @indirect_bmm_vector tt.func @indirect_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt.constancy=16: i32}, %76: index, @@ -744,7 +824,6 @@ tt.func @cross_iter_dep(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, // COMMON: tt.expand_dims // COMMON: tt.expand_dims // COMMON: tt.expand_dims %arg5 -// COMMON-NEXT: tt.expand_dims %arg5 // COMMON: %[[PTR0:.*]] = tt.splat %arg6 // COMMON: %[[PTR1:.*]] = tt.addptr %[[PTR0]] // COMMON-NEXT: tt.load %[[PTR1]] @@ -976,65 
+1055,65 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // AMD: %[[LOCAL_ALLOC_0:.*]] = triton_gpu.local_alloc // AMD: %[[LOCAL_ALLOC_1:.*]] = triton_gpu.local_alloc // AMD: %{{.*}}:7 = scf.for %[[ARG6:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}}) -// AMD: %[[LOCAL_LOAD_47:.*]] = triton_gpu.local_load %[[ARG11]] -// AMD: %[[LOCAL_LOAD_48:.*]] = triton_gpu.local_load %[[ARG12]] -// AMD: %[[DOT_49:.*]] = tt.dot %[[LOCAL_LOAD_47]], %[[LOCAL_LOAD_48]], %[[ARG7]] -// AMD: %[[ADDPTR_50:.*]] = tt.addptr %[[ARG8]], %{{.*}} -// AMD: %[[ADDPTR_51:.*]] = tt.addptr %[[ARG9]], %{{.*}} -// AMD: %[[LOAD_52:.*]] = tt.load %[[ADDPTR_50]] -// AMD: %[[EXPAND_DIMS_53:.*]] = tt.expand_dims %[[ARG13]] {axis = 1 : i32} -// AMD: %[[BROADCAST_54:.*]] = tt.broadcast %[[EXPAND_DIMS_53]] -// AMD: %[[MULI_55:.*]] = arith.muli %{{.*}}, %[[BROADCAST_54]] -// AMD: %[[ADDPTR_56:.*]] = tt.addptr %{{.*}}, %[[MULI_55]] -// AMD: %[[LOAD_57:.*]] = tt.load %[[ADDPTR_56]] -// AMD: %[[LOAD_58:.*]] = tt.load %[[ADDPTR_51]] +// AMD: %[[ADDPTR_47:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// AMD: %[[ADDPTR_48:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// AMD: %[[LOAD_49:.*]] = tt.load %[[ADDPTR_47]] +// AMD: %[[LOCAL_LOAD_50:.*]] = triton_gpu.local_load %[[ARG11]] +// AMD: %[[LOAD_51:.*]] = tt.load %[[ADDPTR_48]] +// AMD: %[[EXPAND_DIMS_52:.*]] = tt.expand_dims %[[ARG12]] {axis = 1 : i32} +// AMD: %[[BROADCAST_53:.*]] = tt.broadcast %[[EXPAND_DIMS_52]] +// AMD: %[[MULI_54:.*]] = arith.muli %{{.*}}, %[[BROADCAST_53]] +// AMD: %[[ADDPTR_55:.*]] = tt.addptr %{{.*}}, %[[MULI_54]] +// AMD: %[[LOAD_56:.*]] = tt.load %[[ADDPTR_55]] +// AMD: %[[LOCAL_LOAD_57:.*]] = triton_gpu.local_load %[[ARG13]] +// AMD: %[[DOT_58:.*]] = tt.dot %[[LOCAL_LOAD_50]], %[[LOCAL_LOAD_57]], %[[ARG7]] // AMD: %[[ADDI_59:.*]] = arith.addi %[[ARG10]], %{{.*}} // AMD: %[[CMPI_60:.*]] = arith.cmpi slt, %[[ADDI_59]], %{{.*}} // AMD: %[[SELECT_61:.*]] = arith.select %[[CMPI_60]], %[[ADDI_59]], %{{.*}} // AMD: %[[MEMDESC_SUBVIEW_62:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_61]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_52]], %[[MEMDESC_SUBVIEW_62]] +// AMD: triton_gpu.local_store %[[LOAD_49]], %[[MEMDESC_SUBVIEW_62]] // AMD: %[[MEMDESC_SUBVIEW_63:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_61]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_57]], %[[MEMDESC_SUBVIEW_63]] -// AMD: scf.yield %[[DOT_49]], %[[ADDPTR_50]], %[[ADDPTR_51]], %[[SELECT_61]], %[[MEMDESC_SUBVIEW_62]], %[[MEMDESC_SUBVIEW_63]], %[[LOAD_58]] -// AMD: } -// AMD: %[[CMPI_21:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}} -// AMD: %[[CMPI_22:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}} -// AMD: %[[LOCAL_LOAD_23:.*]] = triton_gpu.local_load %{{.*}}#4 -// AMD: %[[LOCAL_LOAD_24:.*]] = triton_gpu.local_load %{{.*}}#5 -// AMD: %[[IF_25:.*]] = scf.if %[[CMPI_21]] -// AMD: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_23]], %[[LOCAL_LOAD_24]], %{{.*}}#0 -// AMD: scf.yield %[[DOT_45]] -// AMD: } else { -// AMD: scf.yield %{{.*}}#0 +// AMD: triton_gpu.local_store %[[LOAD_56]], %[[MEMDESC_SUBVIEW_63]] +// AMD: scf.yield %[[DOT_58]], %[[ADDPTR_47]], %[[ADDPTR_48]], %[[SELECT_61]], %[[MEMDESC_SUBVIEW_62]], %[[LOAD_51]], %[[MEMDESC_SUBVIEW_63]] // AMD: } -// AMD: %[[ADDPTR_26:.*]] = tt.addptr %{{.*}}#1, %{{.*}} -// AMD: %[[SPLAT_27:.*]] = tt.splat %[[CMPI_22]] -// AMD: %[[LOAD_28:.*]] = tt.load 
%[[ADDPTR_26]], %[[SPLAT_27]] -// AMD: %[[EXPAND_DIMS_29:.*]] = tt.expand_dims %{{.*}}#6 {axis = 1 : i32} -// AMD: %[[BROADCAST_30:.*]] = tt.broadcast %[[EXPAND_DIMS_29]] -// AMD: %[[MULI_31:.*]] = arith.muli %{{.*}}, %[[BROADCAST_30]] -// AMD: %[[ADDPTR_32:.*]] = tt.addptr %{{.*}}, %[[MULI_31]] -// AMD: %[[SPLAT_33:.*]] = tt.splat %[[CMPI_22]] -// AMD: %[[LOAD_34:.*]] = tt.load %[[ADDPTR_32]], %[[SPLAT_33]] -// AMD: %[[ADDI_35:.*]] = arith.addi %{{.*}}#3, %{{.*}} -// AMD: %[[CMPI_36:.*]] = arith.cmpi slt, %[[ADDI_35]], %{{.*}} -// AMD: %[[SELECT_37:.*]] = arith.select %[[CMPI_36]], %[[ADDI_35]], %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_38:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_37]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_28]], %[[MEMDESC_SUBVIEW_38]] -// AMD: %[[MEMDESC_SUBVIEW_39:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_37]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_34]], %[[MEMDESC_SUBVIEW_39]] -// AMD: %[[SELECT_40:.*]] = arith.select %[[CMPI_21]], %[[IF_25]], %{{.*}}#0 -// AMD: %[[LOCAL_LOAD_41:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_38]] -// AMD: %[[LOCAL_LOAD_42:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_39]] -// AMD: %[[IF_43:.*]] = scf.if %[[CMPI_22]] -// AMD: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_41]], %[[LOCAL_LOAD_42]], %[[SELECT_40]] -// AMD: scf.yield %[[DOT_45]] -// AMD: } else { -// AMD: scf.yield %[[SELECT_40]] -// AMD: } -// AMD: %[[SELECT_44:.*]] = arith.select %[[CMPI_22]], %[[IF_43]], %[[SELECT_40]] -// AMD: triton_gpu.local_dealloc %{{.*}} -// AMD: triton_gpu.local_dealloc %{{.*}} +// AMD: %[[CMPI_21:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}} +// AMD: %[[CMPI_22:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}} +// AMD: %[[ADDPTR_23:.*]] = tt.addptr %{{.*}}#1, %{{.*}} +// AMD: %[[SPLAT_24:.*]] = tt.splat %[[CMPI_22]] +// AMD: %[[LOAD_25:.*]] = tt.load %[[ADDPTR_23]], %[[SPLAT_24]] +// AMD: %[[LOCAL_LOAD_26:.*]] = triton_gpu.local_load %{{.*}}#4 +// AMD: %[[EXPAND_DIMS_27:.*]] = tt.expand_dims %{{.*}}#5 {axis = 1 : i32} +// AMD: %[[BROADCAST_28:.*]] = tt.broadcast %[[EXPAND_DIMS_27]] +// AMD: %[[MULI_29:.*]] = arith.muli %{{.*}}, %[[BROADCAST_28]] +// AMD: %[[ADDPTR_30:.*]] = tt.addptr %{{.*}}, %[[MULI_29]] +// AMD: %[[SPLAT_31:.*]] = tt.splat %[[CMPI_22]] +// AMD: %[[LOAD_32:.*]] = tt.load %[[ADDPTR_30]], %[[SPLAT_31]] +// AMD: %[[LOCAL_LOAD_33:.*]] = triton_gpu.local_load %{{.*}}#6 +// AMD: %[[IF_34:.*]] = scf.if %[[CMPI_21]] +// AMD: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_26]], %[[LOCAL_LOAD_33]], %{{.*}}#0 +// AMD: scf.yield %[[DOT_45]] +// AMD: } else { +// AMD: scf.yield %{{.*}}#0 +// AMD: } +// AMD: %[[ADDI_35:.*]] = arith.addi %{{.*}}#3, %{{.*}} +// AMD: %[[CMPI_36:.*]] = arith.cmpi slt, %[[ADDI_35]], %{{.*}} +// AMD: %[[SELECT_37:.*]] = arith.select %[[CMPI_36]], %[[ADDI_35]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_38:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_37]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_25]], %[[MEMDESC_SUBVIEW_38]] +// AMD: %[[MEMDESC_SUBVIEW_39:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_37]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_32]], %[[MEMDESC_SUBVIEW_39]] +// AMD: %[[SELECT_40:.*]] = arith.select %[[CMPI_21]], %[[IF_34]], %{{.*}}#0 +// AMD: %[[LOCAL_LOAD_41:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_38]] +// AMD: %[[LOCAL_LOAD_42:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_39]] +// AMD: %[[IF_43:.*]] = scf.if %[[CMPI_22]] +// AMD: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_41]], %[[LOCAL_LOAD_42]], 
%[[SELECT_40]] +// AMD: scf.yield %[[DOT_45]] +// AMD: } else { +// AMD: scf.yield %[[SELECT_40]] +// AMD: } +// AMD: %[[SELECT_44:.*]] = arith.select %[[CMPI_22]], %[[IF_43]], %[[SELECT_40]] +// AMD: triton_gpu.local_dealloc %{{.*}} +// AMD: triton_gpu.local_dealloc %{{.*}} #AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> @@ -1242,6 +1321,23 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // AMD: triton_gpu.local_store // AMD: scf.yield // AMD: triton_gpu.local_dealloc + +// AMD_PREFETCH-LABEL: tt.func public @nested_loops +// AMD_PREFETCH-NOT: triton_gpu.local_alloc +// AMD_PREFETCH: scf.for +// AMD_PREFETCH: triton_gpu.local_alloc +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: triton_gpu.local_store +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: triton_gpu.local_load +// AMD_PREFETCH: scf.for +// AMD_PREFETCH: triton_gpu.local_store +// AMD_PREFETCH: tt.load +// AMD_PREFETCH: tt.dot +// AMD_PREFETCH: triton_gpu.local_load +// AMD_PREFETCH: scf.yield +// AMD_PREFETCH: triton_gpu.local_dealloc + #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [2, 1], order = [1, 0]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 2], instrShape = [16, 8]}> #shared = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> @@ -1571,10 +1667,25 @@ tt.func @matmul_nested_ops(%lb : index, %ub : index, %step : index, // AMD: tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] // AMD: tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] // AMD: scf.for -// AMD: arith.select -// AMD: arith.addf // AMD: %[[A:.*]] = tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] // AMD: %[[B:.*]] = tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// AMD: arith.addf +// AMD: arith.select +// AMD: scf.yield + +// AMD_PREFETCH-LABEL: @masked_add_kernel +// AMD_PREFETCH: %[[CONSTANT:.*]] = arith.constant dense<0xFF800000> +// AMD_PREFETCH-COUNT-6: tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// AMD_PREFETCH: scf.for +// AMD_PREFETCH: %[[A:.*]] = tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// AMD_PREFETCH: %[[B:.*]] = tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// AMD_PREFETCH: arith.addf +// AMD_PREFETCH: arith.select +// AMD_PREFETCH: tt.store +// AMD_PREFETCH: scf.yield +// AMD_PREFETCH: tt.store +// AMD_PREFETCH: tt.store +// AMD_PREFETCH: tt.store #blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 67dc3f2b640b..dfd2aac21e9b 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -221,7 +221,8 @@ def make_ttgir(mod, metadata, options): "num_stages == 0. 
Now it will not happen anymore; " "please update to use num_stages == 2 for " "equivalent behavior in the past.") - amd.passes.ttgpuir.add_stream_pipelinev2(pm, options.num_stages) + prefetch = os.getenv("TRITON_HIP_STREAM_PREFETCH", "0") == "1" + amd.passes.ttgpuir.add_stream_pipelinev2(pm, options.num_stages, prefetch) passes.common.add_canonicalizer(pm) amd.passes.ttgpuir.insert_instruction_sched_hints(pm) passes.ttgpuir.add_optimize_dot_operands(pm, True) diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h index d0ffdae28ed5..636743d305f9 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h @@ -7,7 +7,8 @@ namespace mlir { -std::unique_ptr createTritonAMDGPUStreamPipelineV2Pass(int numStages = 2); +std::unique_ptr createTritonAMDGPUStreamPipelineV2Pass(int numStages = 2, + int prefetch = 0); std::unique_ptr createTritonAMDGPUAccelerateMatmulPass(std::string archGenName = std::string(), diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td index 93345b0d6de4..85604dcaca18 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td @@ -18,7 +18,10 @@ def TritonAMDGPUStreamPipelineV2 : Pass<"tritonamdgpu-stream-pipeline-v2", "mlir let options = [ Option<"numStages", "num_stages", "int32_t", /*default*/"2", - "Number of Pipeline stages"> + "Number of Pipeline stages">, + Option<"prefetch", "prefetch", + "int32_t", /*default*/"0", + "Enable prefetch from shared memory"> ]; } diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp index 9371c8b5f897..6be184fe1eb1 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp @@ -366,10 +366,10 @@ struct TritonAMDGPUReorderInstructionsPass moveUpTranspose(m); - if (isPureMatmulProblem(m)) { - scheduleGlobalLoadLocalStore(m); + if (isPureMatmulProblem(m)) sinkSecondLoad(m); - } + else + scheduleGlobalLoadLocalStore(m); } }; } // namespace diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp index 3b4935026c3f..1a4dd8227c73 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp @@ -19,7 +19,7 @@ // modulo schedule and an expander that rewrites the loop and emits a prologue // and epilogue. This pass first calls a helper that will pre-process the IR // to create stream operations and create a modulo schedule. Then we call the -// expander to generate the prologue and new loop. +// expander to generate the prologue and new loop and epilogue. //===----------------------------------------------------------------------===// #define GEN_PASS_CLASSES @@ -56,42 +56,117 @@ static Operation *streamPredication(RewriterBase &rewriter, Operation *op, namespace { -// Encapsulate stream pipelining -// For each `scf.for` create a StreamPipeliner manager. 
+//===----------------------------------------------------------------------===//
+// Software pipelining generally works by anchoring on global load ops in the
+// main loop and rotating the loop to schedule global load ops for future loop
+// iterations together with compute for the current iteration. In this way, we
+// can 1) issue memory operations earlier to hide the latency and 2) break the
+// strong dependency within one loop iteration to give backends flexibility to
+// better interleave instructions for higher instruction-level parallelism.
+//
+// This StreamPipeliner class creates the pipelining schedule and calls the
+// PipelineExpander to rewrite the `scf.for` loop accordingly. A schedule
+// consists of multiple stages, where ops from different stages can overlap
+// execution because the dependencies are loop-carried.
+//
+// The general flow of this process is:
+//
+// 1. The user provides a `num_stages` that specifies how many stages the
+//    pipeline will have. The number of stages must be larger than the distance
+//    from the first independent load to the compute in order to pipeline.
+// 2. A schedule is created based on the distance between the global loads
+//    in the first stages and the compute that uses the loaded values in the
+//    last stage (num_stages - 1). Each operation will be clustered in an
+//    order that best overlaps with other operations (see details below in
+//    the initSchedule method).
+// 3. When the compute is a tt.dot, the scheduler will insert a shared
+//    memory allocation between the global load and tt.dot. The ttg.local_store
+//    will save the global load value to shared memory and the ttg.local_load
+//    will load the relevant tiles for the tt.dot. These operations will be
+//    scheduled according to various scheduling schemes outlined below in the
+//    initSchedule method (see details there).
+// 4. Finally the schedule will be passed to the PipelineExpander to rewrite
+//    accordingly. The new implementation will consist of:
+//    a. Prologue: containing the ramp-up of num_stages-1 stages for
+//       iterations i=[0, num_stages-1).
+//    b. New loop: ordered by cluster and iterated on each operation by
+//       `i + (num_stages-op_stage)`.
+//    c. Epilogue: ramp-down of the last `num_stages-1` iterations for the
+//       ops in stages 1 to last_stage. This must consider that the loop
+//       bounds may be shorter than num_stages. In this case, the epilogue
+//       iterations must align with the prologue.
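+//
+// As a rough illustration (hypothetical case: num_stages = 3, a single tt.dot
+// as the compute, no indirect loads, i.e. the multi-buffered streaming
+// schedule from initSchedule below), iteration i of the rewritten loop
+// performs roughly:
+//
+//   local_store of the tile for iteration i+1   (stage num_stages-2)
+//   global_load of the tile for iteration i+2   (stage 0)
+//   local_load + tt.dot for iteration i         (stage num_stages-1)
+//
+// The prologue issues the global loads and local stores for the first two
+// iterations, and the epilogue drains the local loads and dots of the last
+// two iterations.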
+// class StreamPipeliner { public: - StreamPipeliner(scf::ForOp _forOp, int _numStages) - : forOp(_forOp), schedule(_numStages), numStages(_numStages), + StreamPipeliner(scf::ForOp _forOp, int _numStages, bool _prefetch) + : forOp(_forOp), prefetch(_prefetch), numStages(_numStages + prefetch), + schedule(numStages), axisInfoAnalysis(forOp->getParentOfType()) { options.supportDynamicLoops = true; options.peelEpilogue = true; options.predicateFn = streamPredication; } + LogicalResult pipelineLoop(); + +private: + void initSchedule(int maxIndirectionLevel); + void computeLoadOpsToIndirectionLevelAndUse(); void assignMemoryLayouts(); - void scheduleLoads(DenseSet &rootUsers); + LogicalResult scheduleLoads(DenseSet &rootUsers); void scheduleDependencies(); void scheduleDistanceOneDependencies(); - void scheduleRemainingToLastStage(tt::CoarseSchedule::Cluster afterPrologue); + void scheduleRemainingToLastStage(); - bool preprocessLoopAndBuildSchedule(); - bool pipelineLoop(); + LogicalResult preprocessLoopAndBuildSchedule(); Value createAlloc(Operation *loadOp, ttg::SharedEncodingAttr sharedEnc, unsigned numBuffers); - void createStreamCopy(tt::LoadOp loadOp, Value alloc, Value extractIdx, - tt::CoarseSchedule::Cluster prefetchCluster); + void createStreamCopy(tt::LoadOp loadOp, Value alloc, Value extractIdx); void createStreamOps(); + // Define categories of scheduling details per Operation types. + // The StreamPipeliner schedules 5 types of operations: + // 1. GLOBAL_LOAD: tt.load + // 2. LOCAL_STORE: ttg.local_store (created by the StreamPipeliner) + // 3. LOCAL_LOAD: ttg.local_load (created by the StreamPipeliner) + // 4. COMPUTE: ops that use the loaded data + // 5. TAIL: everything else in the loop + enum SchedType { + SCHED_GLOBAL_LOAD, + SCHED_LOCAL_STORE, + SCHED_LOCAL_LOAD, + SCHED_COMPUTE, + SCHED_TAIL + }; + + void scheduleOp(Operation *op, SchedType type, int stage = -1) { + if (stage < 0) + stage = config[type].stage; + schedule.insert(op, stage, config[type].cluster); + } + private: + // Data members scf::ForOp forOp; - tt::CoarseSchedule schedule; + + // User settings + bool prefetch; int numStages; + // Scheduling clusters + tt::CoarseSchedule schedule; + + // ScheduleConfig lookup by SchedType to get the stage and cluster. + struct ScheduleConfig { + int stage; + tt::CoarseSchedule::Cluster cluster; + }; + SmallVector config; + // Mapping and indirection level for each `tt.load` to its use. - llvm::SmallVector> - loadOpToIndLevelAndUse; + SmallVector> loadOpToIndLevelAndUse; struct LoadInfo { // Shared layout is used for loads feeding into dot ops. @@ -116,9 +191,64 @@ class StreamPipeliner { } // namespace -void StreamPipeliner::createStreamCopy( - tt::LoadOp loadOp, Value alloc, Value extractIdx, - tt::CoarseSchedule::Cluster prefetchCluster) { +// Init Schedule Config based on settings and loop characteristics. +// Create clusters in order of ops in loop. This can interleave ops +// from different stages in the same cluster to achieve better backend +// scheduling. +// WARNING: Changing the order of schedule.clusters.newAtBack() calls +// can cause invalid schedules to be produced. +void StreamPipeliner::initSchedule(int maxIndirectionLevel) { + int lastStage = numStages - 1; + config.resize(SCHED_TAIL + 1); + + bool isMultibuf = numStages > (2 + maxIndirectionLevel); + if (prefetch) { + // Prefetch Schema cluster order and staging. 
+ // for i in (...): + // local_stores: stage=i+1 + // global_loads: stage=i+2 + // compute: stage=i + // local_load: stage=i+1 + // tail: stage=i + config[SCHED_LOCAL_STORE] = {lastStage - 1, schedule.clusters.newAtBack()}; + auto cluster1 = schedule.clusters.newAtBack(); + config[SCHED_GLOBAL_LOAD] = {0, cluster1}; + config[SCHED_COMPUTE] = {lastStage, cluster1}; + config[SCHED_LOCAL_LOAD] = {lastStage - 1, schedule.clusters.newAtBack()}; + config[SCHED_TAIL] = {lastStage, schedule.clusters.newAtBack()}; + } else if (isMultibuf) { + // Streaming Schema cluster order and staging for multi-buffer. + // for i in (...): + // local_stores: stage=i+1 + // global_loads: stage=i+2 + // local_load: stage=i + // compute: stage=i + // tail: stage=i + config[SCHED_LOCAL_STORE] = {lastStage - 1, schedule.clusters.newAtBack()}; + auto cluster1 = schedule.clusters.newAtBack(); + config[SCHED_GLOBAL_LOAD] = {0, cluster1}; + config[SCHED_LOCAL_LOAD] = {lastStage, cluster1}; + config[SCHED_COMPUTE] = {lastStage, cluster1}; + config[SCHED_TAIL] = {lastStage, schedule.clusters.newAtBack()}; + } else { + // Streaming Schema cluster order and staging for single-buffer. + // for i in (...): + // global_loads: stage=i+1 + // local_load: stage=i + // compute: stage=i + // local_stores: stage=i+1 + // tail: stage=i + auto cluster0 = schedule.clusters.newAtBack(); + config[SCHED_GLOBAL_LOAD] = {0, cluster0}; + config[SCHED_LOCAL_LOAD] = {lastStage, schedule.clusters.newAtBack()}; + config[SCHED_COMPUTE] = {lastStage, cluster0}; + config[SCHED_LOCAL_STORE] = {lastStage - 1, schedule.clusters.newAtBack()}; + config[SCHED_TAIL] = {lastStage, schedule.clusters.newAtBack()}; + } +} + +void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc, + Value extractIdx) { OpBuilder builder(forOp); Value zero = builder.create(forOp.getLoc(), 0, 32); // Replace the load with insert/extract slice. @@ -126,6 +256,7 @@ void StreamPipeliner::createStreamCopy( Location loc = loadOp.getLoc(); Value src = loadOp.getPtr(); Value mask = loadOp.getMask(); + Value other = loadOp.getOther(); tt::MemDescType allocTy = cast(alloc.getType()); SmallVector copyOffsets(allocTy.getRank(), zero); @@ -145,8 +276,6 @@ void StreamPipeliner::createStreamCopy( allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true); auto viewLoad = builder.create(loc, subviewTy, alloc, loadOffsets); - auto storeOp = - builder.create(loc, copy->getResult(0), viewLoad); // Clean up old local caches. SmallVector allocsToErase; for (Operation *user : loadOp->getUsers()) { @@ -158,17 +287,18 @@ void StreamPipeliner::createStreamCopy( for (auto alloc : allocsToErase) alloc.erase(); + // Prefetch load ahead of the dot stage if is used by the dot. + auto storeOp = + builder.create(loc, copy->getResult(0), viewLoad); + scheduleOp(viewLoad, SCHED_LOCAL_STORE); + scheduleOp(storeOp, SCHED_LOCAL_STORE); + + // Create local load auto sharedLoad = builder.create(loc, loadOp.getType(), viewLoad); - auto result = sharedLoad->getResults(); - - // Create a select for non-zero other values. 
- Value other = loadOp.getOther(); - if (other && !isZeroConst(other)) { - auto select = builder.create( - loc, loadOp.getType(), mask, sharedLoad.getResult(), other); - result = select->getResults(); - } + Value result = sharedLoad.getResult(); + if (prefetch) + scheduleOp(sharedLoad, SCHED_LOCAL_LOAD); // If the currently processed `LoadOp` is labeled with an index regarding // to which `DotOp` operand the corresponding data belongs to, then label the @@ -179,14 +309,13 @@ void StreamPipeliner::createStreamCopy( storeOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), attr); } - loadOp->replaceAllUsesWith(result); + loadOp->replaceAllUsesWith(ValueRange{result}); - // Prefetch load ahead of the dot stage if is used by the dot. - if (loadToInfo[loadOp].usedByDot) { - assert(numStages >= 2 && "requires num_stages=2 at least"); - schedule.insert(storeOp, numStages - 2, prefetchCluster); - schedule.insert(viewLoad, numStages - 2, prefetchCluster); + if (prefetch && result.hasOneUse()) { + if (auto cvt = dyn_cast(*result.getUsers().begin())) + scheduleOp(cvt, SCHED_LOCAL_LOAD); } + loadOp.erase(); } @@ -348,7 +477,7 @@ void StreamPipeliner::assignMemoryLayouts() { } } -void StreamPipeliner::scheduleLoads(DenseSet &rootUsers) { +LogicalResult StreamPipeliner::scheduleLoads(DenseSet &rootUsers) { // Get all loads that are (transitively) used by dot ops and their distance // to the dot op. computeLoadOpsToIndirectionLevelAndUse(); @@ -361,12 +490,12 @@ void StreamPipeliner::scheduleLoads(DenseSet &rootUsers) { } }); if (loadOpToIndLevelAndUse.empty()) - return; + return failure(); // Check which loads are good for pipelining, and assign them memory layouts. assignMemoryLayouts(); if (loadToInfo.empty()) - return; + return failure(); // Filter out load ops that cannot be pipelined. int resize = 0; @@ -382,6 +511,12 @@ void StreamPipeliner::scheduleLoads(DenseSet &rootUsers) { for (auto [loadOp, dist, use] : loadOpToIndLevelAndUse) maxIndirectionLevel = std::max(maxIndirectionLevel, dist); + LDBG("maxIndirectionLevel = " << maxIndirectionLevel); + if (maxIndirectionLevel >= numStages) + return failure(); + + initSchedule(maxIndirectionLevel); + // The stage gap between chained loads--this allows us to "spread" loads // with a non-one step in case the number of stages given by the user is // large. @@ -391,24 +526,18 @@ void StreamPipeliner::scheduleLoads(DenseSet &rootUsers) { LDBG("stagesBetweenLoads = " << stagesBetweenLoads); // Put the root uses of the loads in the last stage. - tt::CoarseSchedule::Cluster rootUsersCluster = schedule.clusters.newAtFront(); for (auto &[loadOp, dist, use] : loadOpToIndLevelAndUse) { // Non-LoadOp(s) are the (final) root uses of all LoadOp(s). if (!isa(use)) { - schedule.insert(use, numStages - 1, rootUsersCluster); + scheduleOp(use, SCHED_COMPUTE); rootUsers.insert(use); } } - // Create a cluster for load ops at each indirection level. - SmallVector loadsClusters; - for (int i = 0; i <= maxIndirectionLevel; i++) { - loadsClusters.push_back(schedule.clusters.newAtBack()); - } // Assign stages to the loads. for (auto [loadOp, indLevel, _] : loadOpToIndLevelAndUse) { int stage = (maxIndirectionLevel - indLevel) * stagesBetweenLoads; - schedule.insert(loadOp, stage, loadsClusters[indLevel]); + scheduleOp(loadOp, SCHED_GLOBAL_LOAD, stage); } // Calculate distance from the load to the use. 
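+  // As a hypothetical illustration of the stage assignment above: for an
+  // indirect GEMM where an index tensor is loaded first (indLevel = 1) and
+  // then used to load the operand data (indLevel = 0), with
+  // maxIndirectionLevel = 1 and stagesBetweenLoads = 1, the index load is
+  // assigned stage 0, the data load stage 1, and the compute that consumes it
+  // stays in the last stage (numStages - 1).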
@@ -424,6 +553,8 @@ void StreamPipeliner::scheduleLoads(DenseSet &rootUsers) { LDBG(" usedByDot: " << info.usedByDot); } }); + + return success(); } // Add dependencies of anchor ops to the coarse schedule. Schedule them to @@ -490,22 +621,23 @@ void StreamPipeliner::scheduleDistanceOneDependencies() { } } -void StreamPipeliner::scheduleRemainingToLastStage( - tt::CoarseSchedule::Cluster afterPrologue) { +void StreamPipeliner::scheduleRemainingToLastStage() { + int lastStage = numStages - 1; // Assign the rest of the ops to the last stage. // Take care of the ordering of the ops - uses cannot be scheduled to the // cluster before the definition. DenseMap opToCluster; for (auto &op : forOp.getBody()->without_terminator()) { if (schedule.count(&op) == 0) { - opToCluster[&op] = afterPrologue; + auto schedType = isa(op) ? SCHED_COMPUTE : SCHED_TAIL; + opToCluster[&op] = config[schedType].cluster; } } SmallVector queue; for (auto [op, stage, cluster] : schedule.getOpsInOrder(forOp)) { // We really only care about the producers from the last stage. // Others will be scheduled before these ops anyway. - if (stage == numStages - 1) { + if (stage == lastStage) { queue.push_back(op); } } @@ -523,7 +655,7 @@ void StreamPipeliner::scheduleRemainingToLastStage( } } for (auto [op, cluster] : opToCluster) { - schedule.insert(op, numStages - 1, cluster); + schedule.insert(op, lastStage, cluster); } } @@ -540,8 +672,10 @@ Value StreamPipeliner::createAlloc(Operation *loadOp, Type memdescType = tt::MemDescType::get(bufferShape, ty.getElementType(), sharedEnc, sharedMemorySpace, /*mutableMemory=*/true); - return builder.create(loadOp->getLoc(), memdescType, - Value()); + auto alloc = + builder.create(loadOp->getLoc(), memdescType, Value()); + sharedMemAllocs.push_back(alloc); + return alloc; } // Convert load ops into shared memory allocation loads and apply @@ -549,19 +683,20 @@ Value StreamPipeliner::createAlloc(Operation *loadOp, void StreamPipeliner::createStreamOps() { // Calculate the number of buffers needed for each load. // TODO: Use the precise number of buffers needed by the particular load. - int numBuffers = -1; - for (auto &[_, info] : loadToInfo) - numBuffers = std::max(numBuffers, info.distToUse); - LDBG("deduced shared memory buffer number = " << numBuffers); + int maxNumBuffers = -1; + for (auto &[_, info] : loadToInfo) { + int sharedBuffers = info.distToUse - (info.usedByDot ? prefetch : 0); + maxNumBuffers = std::max(maxNumBuffers, sharedBuffers); + } + LDBG("deduced max shared memory buffer number = " << maxNumBuffers); SmallVector> loadToAllocs; for (auto &[loadOp, info] : loadToInfo) { if (!info.sharedEncoding) continue; - Value alloc = createAlloc(loadOp, info.sharedEncoding, numBuffers); + Value alloc = createAlloc(loadOp, info.sharedEncoding, maxNumBuffers); assert(alloc && "Failed to create alloc for the async load."); - sharedMemAllocs.push_back(alloc); loadToAllocs.emplace_back(loadOp, alloc); } @@ -574,7 +709,7 @@ void StreamPipeliner::createStreamOps() { Value one = builder.create(loc, 1, 32); Value extractIdx = minusOne; Value numBuffersVal = - builder.create(loc, numBuffers, 32); + builder.create(loc, maxNumBuffers, 32); unsigned newOperandIndex = forOp.getBody()->getNumArguments(); // Patch the loop to add the new loop carried dependencies. @@ -593,24 +728,23 @@ void StreamPipeliner::createStreamOps() { extractIdx, numBuffersVal); extractIdx = builder.create(loc, cndExt, extractIdx, zero); - // Create a cluster for prefetching global reads for the dot. 
- tt::CoarseSchedule::Cluster prefetchCluster = schedule.clusters.newAtBack(); - + // Create stream copies. for (auto &[op, alloc] : loadToAllocs) { if (auto loadOp = dyn_cast(op)) - createStreamCopy(loadOp, alloc, extractIdx, prefetchCluster); + createStreamCopy(loadOp, alloc, extractIdx); } // Patch the yield with the updated counters. appendToForOpYield(forOp, {extractIdx}); } -bool StreamPipeliner::preprocessLoopAndBuildSchedule() { +LogicalResult StreamPipeliner::preprocessLoopAndBuildSchedule() { // Schedule the loads and root ops (dot ops) in the loop. This will give us // a scaffold for the final schedule. DenseSet rootUsers; - scheduleLoads(rootUsers); + if (failed(scheduleLoads(rootUsers))) + return failure(); if (loadToInfo.empty()) - return false; + return failure(); LLVM_DEBUG({ LDBG("Coarse schedule loads only:"); @@ -620,13 +754,6 @@ bool StreamPipeliner::preprocessLoopAndBuildSchedule() { // Convert the loads into shared memory allocations and loads from them. createStreamOps(); - LLVM_DEBUG({ - LDBG("Coarse schedule with stream loads:"); - schedule.dump(); - }); - - tt::CoarseSchedule::Cluster afterPrologue = schedule.clusters.begin(); - scheduleDependencies(); LLVM_DEBUG({ LDBG("Coarse schedule with dependencies:"); @@ -639,7 +766,7 @@ bool StreamPipeliner::preprocessLoopAndBuildSchedule() { schedule.dump(); }); - scheduleRemainingToLastStage(afterPrologue); + scheduleRemainingToLastStage(); LLVM_DEBUG({ LDBG("Final coarse schedule:"); schedule.dump(); @@ -662,7 +789,18 @@ bool StreamPipeliner::preprocessLoopAndBuildSchedule() { // Explicitly deallocate created allocations. for (auto alloc : sharedMemAllocs) builder.create(forOp.getLoc(), alloc); - return true; + + return success(); +} + +LogicalResult StreamPipeliner::pipelineLoop() { + if (failed(preprocessLoopAndBuildSchedule())) + return failure(); + LDBG("Loop before sending to expander:\n" << *forOp); + + IRRewriter rewriter(forOp->getContext()); + rewriter.setInsertionPoint(forOp); + return tt::pipelineForLoop(rewriter, forOp, options); } // Return true if the preconditions for pipelining the loop are met. @@ -682,19 +820,6 @@ static bool checkPrecondition(scf::ForOp forOp) { return !forOp->walk(hasNestedLoopInside).wasInterrupted(); } -bool StreamPipeliner::pipelineLoop() { - if (!checkPrecondition(forOp)) - return false; - - if (!preprocessLoopAndBuildSchedule()) - return false; - LDBG("Loop before sending to expander:\n" << *forOp); - - IRRewriter rewriter(forOp->getContext()); - rewriter.setInsertionPoint(forOp); - return succeeded(tt::pipelineForLoop(rewriter, forOp, options)); -} - namespace { // Go through a single use chain to get the result of the target op after all // unary ops - e.g., `convert_layout`, `fp_to_fp`, etc. 
@@ -733,7 +858,10 @@ void labelLoadOpsForTritonDot(scf::ForOp forOp) { struct PipelinePass : public TritonAMDGPUStreamPipelineV2Base { PipelinePass() = default; - PipelinePass(int32_t numStages) { this->numStages = numStages; } + PipelinePass(int32_t numStages, int32_t prefetch) { + this->numStages = numStages; + this->prefetch = prefetch; + } void runOnOperation() override { SmallVector loops; @@ -745,8 +873,11 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineV2Base { }); for (scf::ForOp forOp : loops) { - StreamPipeliner sp(forOp, getNumStagesOrDefault(forOp)); - sp.pipelineLoop(); + if (!checkPrecondition(forOp)) + continue; + StreamPipeliner sp(forOp, getNumStagesOrDefault(forOp), prefetch); + if (failed(sp.pipelineLoop())) + continue; } } @@ -762,6 +893,6 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineV2Base { } // anonymous namespace std::unique_ptr -mlir::createTritonAMDGPUStreamPipelineV2Pass(int numStages) { - return std::make_unique(numStages); +mlir::createTritonAMDGPUStreamPipelineV2Pass(int numStages, int prefetch) { + return std::make_unique(numStages, prefetch); } diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc index 9eab3771263e..a9bd3e9b7fb7 100644 --- a/third_party/amd/python/triton_amd.cc +++ b/third_party/amd/python/triton_amd.cc @@ -72,8 +72,8 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) { mlir::createTritonAMDGPUConvertToBufferOpsPass); ADD_PASS_WRAPPER_0("add_reorder_instructions", mlir::createTritonAMDGPUReorderInstructionsPass); - ADD_PASS_WRAPPER_1("add_stream_pipelinev2", - mlir::createTritonAMDGPUStreamPipelineV2Pass, int); + ADD_PASS_WRAPPER_2("add_stream_pipelinev2", + mlir::createTritonAMDGPUStreamPipelineV2Pass, int, int); } void addControlConstant(llvm::Module *module, const char *name,