From 6e03ce850924b7b0e1704f5b6df9982cd4a14508 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Thu, 14 Nov 2024 23:45:07 +0000 Subject: [PATCH] Revert "[AMD][Pipeliner] Improve clustering and add prefetch (#4881)" This reverts commit cc25374fa480c0b3e51cf218ed6fe7eb4c50a5bb. --- .../amd/amd-reorder-instructions.mlir | 345 +++++++++++++ test/TritonGPU/amd/amd-sched-2nd-load.mlir | 32 +- test/TritonGPU/loop-pipeline-hip.mlir | 4 +- test/TritonGPU/loop-pipeline.mlir | 475 +++++++----------- third_party/amd/backend/compiler.py | 3 +- .../include/TritonAMDGPUTransforms/Passes.h | 3 +- .../include/TritonAMDGPUTransforms/Passes.td | 5 +- .../ReorderInstructions.cpp | 8 +- .../StreamPipelineV2.cpp | 317 ++++-------- third_party/amd/python/triton_amd.cc | 4 +- 10 files changed, 648 insertions(+), 548 deletions(-) diff --git a/test/TritonGPU/amd/amd-reorder-instructions.mlir b/test/TritonGPU/amd/amd-reorder-instructions.mlir index d7be023312ea..5dfd0f2a5f4c 100644 --- a/test/TritonGPU/amd/amd-reorder-instructions.mlir +++ b/test/TritonGPU/amd/amd-reorder-instructions.mlir @@ -115,6 +115,351 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war } } +// ----- +// Move loads (and independent local_stores) as early as possible. +// For example in the matmul_loop below, the scf.for loop looks like this after pipeliner: +// scf.for ... { +// // stage 1 +// %a = tt.local_load %a_tile +// %b = tt.local_load %b_tile +// tt.dot %c, %a, %b +// // stage 0 +// %aptr = tt.addptr %aptr, %k +// %a_next = tt.load %aptr +// %bptr = tt.addptr %bptr, %k +// %b_next = tt.load %bptr +// tt.local_store %a_next +// tt.local_store %b_next +// yield +// } +// +// Solution for num_stages=2 : +// scf.for ... { +// // stage 0.a +// %aptr = tt.addptr %aptr, %k +// %a_next = tt.load %aptr +// %bptr = tt.addptr %bptr, %k +// %b_next = tt.load %bptr +// // stage 1 +// %a = tt.local_load %a_tile +// %b = tt.local_load %b_tile +// tt.dot %c, %a, %b +// // stage 0.b +// tt.local_store %a_next +// tt.local_store %b_next +// yield +// } +// +// Solution for num_stages=3 (double-buffered) : +// scf.for ... 
{ +// // stage 1 +// tt.local_store %a_next_1 +// tt.local_store %b_next_1 +// // stage 0 +// %aptr = tt.addptr %aptr, %k +// %a_next_2 = tt.load %aptr +// %bptr = tt.addptr %bptr, %k +// %b_next_2 = tt.load %bptr +// // stage 2 +// %a = tt.local_load %a_tile +// %b = tt.local_load %b_tile +// tt.dot %c, %a, %b +// yield +// } + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = []}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#shared2 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], hasLeadingOffset = false}> +#shared3 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#shared4 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32, triton_gpu.target = "hip:gfx942"} { + +// CHECK-LABEL: tt.func @matmul_loop +// CHECK: %{{.*}}:6 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}) +// Stage 0.a +// CHECK: %[[ADDPTR_20:.*]] = tt.addptr %[[ARG6]], %{{.*}} +// CHECK: %[[SUBI_21:.*]] = arith.subi %{{.*}}, %{{.*}} +// CHECK: %[[CMPI_22:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_21]] +// CHECK: %[[SPLAT_23:.*]] = tt.splat %[[CMPI_22]] +// CHECK: %[[LOAD_24:.*]] = tt.load %[[ADDPTR_20]], %[[SPLAT_23]] +// CHECK: %[[ADDPTR_25:.*]] = tt.addptr %[[ARG7]], %{{.*}} +// CHECK: %[[SPLAT_26:.*]] = tt.splat %[[CMPI_22]] +// CHECK: %[[LOAD_27:.*]] = tt.load %[[ADDPTR_25]], %[[SPLAT_26]] +// Stage 1 +// CHECK: %[[LOCAL_LOAD_28:.*]] = triton_gpu.local_load %[[ARG10]] +// CHECK: %[[LOCAL_LOAD_29:.*]] = triton_gpu.local_load %[[ARG11]] +// CHECK: %[[MULF_30:.*]] = arith.mulf %[[LOCAL_LOAD_29]], %{{.*}} +// CHECK: %[[DOT_31:.*]] = tt.dot %[[LOCAL_LOAD_28]], %[[MULF_30]], %[[ARG8]] +// Stage 0.b +// CHECK: %[[ADDI_32:.*]] = arith.addi %[[ARG9]], %{{.*}} +// CHECK: %[[CMPI_33:.*]] = arith.cmpi slt, %[[ADDI_32]], %{{.*}} +// CHECK: %[[SELECT_34:.*]] = arith.select %[[CMPI_33]], %[[ADDI_32]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_35:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_34]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_24]], %[[MEMDESC_SUBVIEW_35]] +// CHECK: %[[MEMDESC_SUBVIEW_36:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_34]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_27]], %[[MEMDESC_SUBVIEW_36]] +// CHECK: scf.yield %[[ADDPTR_20]], %[[ADDPTR_25]], %[[DOT_31]], %[[SELECT_34]], %[[MEMDESC_SUBVIEW_35]], %[[MEMDESC_SUBVIEW_36]] +// CHECK: } + + tt.func @matmul_loop(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #mma> { + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<4.000000e+00> : tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, 
parent = #mma, kWidth = 2}>> + %cst_0 = arith.constant dense<4> : tensor<32x128xi32, #blocked> + %cst_1 = arith.constant dense<4> : tensor<128x32xi32, #blocked1> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<32x128xf16, #blocked> + %0 = tt.splat %arg3 : !tt.ptr -> tensor<128x32x!tt.ptr, #blocked1> + %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> + %3 = tt.broadcast %2 : tensor<1x32xi32, #blocked1> -> tensor<128x32xi32, #blocked1> + %4 = tt.addptr %0, %3 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %5 = tt.splat %arg4 : !tt.ptr -> tensor<32x128x!tt.ptr, #blocked> + %6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %8 = tt.broadcast %7 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> + %9 = tt.addptr %5, %8 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %10 = triton_gpu.local_alloc : () -> !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + %11 = triton_gpu.local_alloc : () -> !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + %12 = arith.cmpi slt, %arg0, %arg1 : index + %13 = tt.splat %12 : i1 -> tensor<128x32xi1, #blocked1> + %14 = tt.load %4, %13 : tensor<128x32x!tt.ptr, #blocked1> + %15 = tt.splat %12 : i1 -> tensor<32x128xi1, #blocked> + %16 = tt.load %9, %15, %cst_3 : tensor<32x128x!tt.ptr, #blocked> + %17 = triton_gpu.memdesc_subview %10[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %14, %17 : tensor<128x32xf16, #blocked1> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + %18 = triton_gpu.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %16, %18 : tensor<32x128xf16, #blocked> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + %19:6 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %4, %arg7 = %9, %arg8 = %cst_2, %arg9 = %c0_i32, %arg10 = %17, %arg11 = %18) -> (tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>) { + %20 = arith.subi %arg1, %arg2 : index + %21 = arith.cmpi slt, %arg5, %20 : index + %22 = triton_gpu.local_load %arg10 : !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %23 = triton_gpu.local_load %arg11 : !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %24 = arith.mulf %23, %cst : tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %25 = tt.dot %22, %24, %arg8 : tensor<128x32xf16, 
#triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> + %26 = tt.addptr %arg6, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %27 = tt.addptr %arg7, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %28 = tt.splat %21 : i1 -> tensor<128x32xi1, #blocked1> + %29 = tt.load %26, %28 : tensor<128x32x!tt.ptr, #blocked1> + %30 = tt.splat %21 : i1 -> tensor<32x128xi1, #blocked> + %31 = tt.load %27, %30, %cst_3 : tensor<32x128x!tt.ptr, #blocked> + %32 = arith.addi %arg9, %c1_i32 : i32 + %33 = arith.cmpi slt, %32, %c1_i32 : i32 + %34 = arith.select %33, %32, %c0_i32 : i32 + %35 = triton_gpu.memdesc_subview %10[%34, %c0_i32, %c0_i32] : !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %29, %35 : tensor<128x32xf16, #blocked1> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + %36 = triton_gpu.memdesc_subview %11[%34, %c0_i32, %c0_i32] : !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %31, %36 : tensor<32x128xf16, #blocked> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + scf.yield %26, %27, %25, %34, %35, %36 : tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + } + triton_gpu.local_dealloc %10 : !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %11 : !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + tt.return %19#2 : tensor<128x128xf32, #mma> + } + + +// This example tests that tt.load overlaps with independent ttg.local_store which +// overlaps with independent tt.dot. 
+// num_stages == 3, double buffered + +// CHECK-LABEL: tt.func @matmul_loop_mb +// CHECK: %{{.*}}:8 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}}) +// Stage 0 +// CHECK: %[[ADDPTR_28:.*]] = tt.addptr %[[ARG6]], %{{.*}} +// CHECK: %[[MULI_29:.*]] = arith.muli %{{.*}}, %{{.*}} +// CHECK: %[[SUBI_30:.*]] = arith.subi %{{.*}}, %[[MULI_29]] +// CHECK: %[[CMPI_31:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_30]] +// CHECK: %[[SPLAT_32:.*]] = tt.splat %[[CMPI_31]] +// CHECK: %[[LOAD_33:.*]] = tt.load %[[ADDPTR_28]], %[[SPLAT_32]] +// CHECK: %[[ADDPTR_34:.*]] = tt.addptr %[[ARG7]], %{{.*}} +// CHECK: %[[SPLAT_35:.*]] = tt.splat %[[CMPI_31]] +// CHECK: %[[LOAD_36:.*]] = tt.load %[[ADDPTR_34]], %[[SPLAT_35]] +// Stage 1 +// CHECK: %[[ADDI_37:.*]] = arith.addi %[[ARG9]], %{{.*}} +// CHECK: %[[CMPI_38:.*]] = arith.cmpi slt, %[[ADDI_37]], %{{.*}} +// CHECK: %[[SELECT_39:.*]] = arith.select %[[CMPI_38]], %[[ADDI_37]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_40:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[ARG12]], %[[MEMDESC_SUBVIEW_40]] +// CHECK: %[[MEMDESC_SUBVIEW_41:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[ARG13]], %[[MEMDESC_SUBVIEW_41]] +// Stage 2 +// CHECK: %[[LOCAL_LOAD_42:.*]] = triton_gpu.local_load %[[ARG10]] +// CHECK: %[[LOCAL_LOAD_43:.*]] = triton_gpu.local_load %[[ARG11]] +// CHECK: %[[MULF_44:.*]] = arith.mulf %[[LOCAL_LOAD_43]], %{{.*}} +// CHECK: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_42]], %[[MULF_44]], %[[ARG8]] +// CHECK: scf.yield %[[ADDPTR_28]], %[[ADDPTR_34]], %[[DOT_45]], %[[SELECT_39]], %[[MEMDESC_SUBVIEW_40]], %[[MEMDESC_SUBVIEW_41]], %[[LOAD_33]], %[[LOAD_36]] +// CHECK: } + + tt.func @matmul_loop_mb(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #mma> { + %c2 = arith.constant 2 : index + %c2_i32 = arith.constant 2 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<4.000000e+00> : tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %cst_0 = arith.constant dense<4> : tensor<32x128xi32, #blocked> + %cst_1 = arith.constant dense<4> : tensor<128x32xi32, #blocked1> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<32x128xf16, #blocked> + %0 = tt.splat %arg3 : !tt.ptr -> tensor<128x32x!tt.ptr, #blocked1> + %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> + %3 = tt.broadcast %2 : tensor<1x32xi32, #blocked1> -> tensor<128x32xi32, #blocked1> + %4 = tt.addptr %0, %3 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %5 = tt.splat %arg4 : !tt.ptr -> tensor<32x128x!tt.ptr, #blocked> + %6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %8 = 
tt.broadcast %7 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> + %9 = tt.addptr %5, %8 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %10 = triton_gpu.local_alloc : () -> !tt.memdesc<2x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + %11 = triton_gpu.local_alloc : () -> !tt.memdesc<2x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + %12 = arith.cmpi slt, %arg0, %arg1 : index + %13 = tt.splat %12 : i1 -> tensor<128x32xi1, #blocked1> + %14 = tt.load %4, %13 : tensor<128x32x!tt.ptr, #blocked1> + %15 = tt.splat %12 : i1 -> tensor<32x128xi1, #blocked> + %16 = tt.load %9, %15, %cst_3 : tensor<32x128x!tt.ptr, #blocked> + %17 = arith.addi %arg0, %arg2 : index + %18 = arith.cmpi slt, %17, %arg1 : index + %19 = tt.addptr %4, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %20 = tt.addptr %9, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %21 = tt.splat %18 : i1 -> tensor<128x32xi1, #blocked1> + %22 = tt.load %19, %21 : tensor<128x32x!tt.ptr, #blocked1> + %23 = tt.splat %18 : i1 -> tensor<32x128xi1, #blocked> + %24 = tt.load %20, %23, %cst_3 : tensor<32x128x!tt.ptr, #blocked> + %25 = triton_gpu.memdesc_subview %10[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<2x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %14, %25 : tensor<128x32xf16, #blocked1> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + %26 = triton_gpu.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<2x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %16, %26 : tensor<32x128xf16, #blocked> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + %27:8 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %19, %arg7 = %20, %arg8 = %cst_2, %arg9 = %c0_i32, %arg10 = %25, %arg11 = %26, %arg12 = %22, %arg13 = %24) -> (tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>, tensor<128x32xf16, #blocked1>, tensor<32x128xf16, #blocked>) { + %28 = arith.muli %arg2, %c2 : index + %29 = arith.subi %arg1, %28 : index + %30 = arith.cmpi slt, %arg5, %29 : index + %31 = triton_gpu.local_load %arg10 : !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %32 = triton_gpu.local_load %arg11 : !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %33 = arith.mulf %32, %cst : tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %34 = tt.dot %31, %33, %arg8 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> + %35 = tt.addptr %arg6, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %36 = tt.addptr %arg7, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %37 = tt.splat %30 : i1 -> tensor<128x32xi1, #blocked1> + %38 = tt.load %35, %37 : tensor<128x32x!tt.ptr, #blocked1> + %39 = tt.splat %30 : i1 -> 
tensor<32x128xi1, #blocked> + %40 = tt.load %36, %39, %cst_3 : tensor<32x128x!tt.ptr, #blocked> + %41 = arith.addi %arg9, %c1_i32 : i32 + %42 = arith.cmpi slt, %41, %c2_i32 : i32 + %43 = arith.select %42, %41, %c0_i32 : i32 + %44 = triton_gpu.memdesc_subview %10[%43, %c0_i32, %c0_i32] : !tt.memdesc<2x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %arg12, %44 : tensor<128x32xf16, #blocked1> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + %45 = triton_gpu.memdesc_subview %11[%43, %c0_i32, %c0_i32] : !tt.memdesc<2x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %arg13, %45 : tensor<32x128xf16, #blocked> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + scf.yield %35, %36, %34, %43, %44, %45, %38, %40 : tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>, tensor<128x32xf16, #blocked1>, tensor<32x128xf16, #blocked> + } + triton_gpu.local_dealloc %10 : !tt.memdesc<2x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %11 : !tt.memdesc<2x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + tt.return %27#2 : tensor<128x128xf32, #mma> + } + +// This example shows dependent loads and verifies all are moved early. +// CHECK-LABEL: tt.func @indirect_bmm_vector +// CHECK: %{{.*}}:7 = scf.for %[[ARG6:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}}) +// Stage 0 +// CHECK: %[[ADDPTR_20:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// CHECK: %[[SUBI_21:.*]] = arith.subi %{{.*}}, %{{.*}} +// CHECK: %[[CMPI_22:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_21]] +// CHECK: %[[SPLAT_23:.*]] = tt.splat %[[CMPI_22]] +// CHECK: %[[LOAD_24:.*]] = tt.load %[[ADDPTR_20]], %[[SPLAT_23]] +// Stage 1.a +// CHECK: %[[EXPAND_DIMS_25:.*]] = tt.expand_dims %[[ARG13]] {axis = 1 : i32} +// CHECK: %[[BROADCAST_26:.*]] = tt.broadcast %[[EXPAND_DIMS_25]] +// CHECK: %[[MULI_27:.*]] = arith.muli %{{.*}}, %[[BROADCAST_26]] +// CHECK: %[[ADDPTR_28:.*]] = tt.addptr %{{.*}}, %[[MULI_27]] +// CHECK: %[[SPLAT_29:.*]] = tt.splat %[[CMPI_22]] +// CHECK: %[[LOAD_30:.*]] = tt.load %[[ADDPTR_28]], %[[SPLAT_29]] +// CHECK: %[[ADDPTR_31:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// CHECK: %[[SUBI_32:.*]] = arith.subi %{{.*}}, %{{.*}} +// CHECK: %[[CMPI_33:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_32]] +// CHECK: %[[SPLAT_34:.*]] = tt.splat %[[CMPI_33]] +// CHECK: %[[LOAD_35:.*]] = tt.load %[[ADDPTR_31]], %[[SPLAT_34]] +// Stage 2 +// CHECK: %[[LOCAL_LOAD_36:.*]] = triton_gpu.local_load %[[ARG11]] +// CHECK: %[[LOCAL_LOAD_37:.*]] = triton_gpu.local_load %[[ARG12]] +// CHECK: %[[DOT_38:.*]] = tt.dot %[[LOCAL_LOAD_36]], %[[LOCAL_LOAD_37]], %[[ARG7]] +// Stage 1.b +// CHECK: %[[ADDI_39:.*]] = arith.addi %[[ARG10]], %{{.*}} +// CHECK: %[[CMPI_40:.*]] = arith.cmpi slt, %[[ADDI_39]], %{{.*}} +// CHECK: %[[SELECT_41:.*]] = arith.select %[[CMPI_40]], %[[ADDI_39]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_42:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_41]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store 
%[[LOAD_24]], %[[MEMDESC_SUBVIEW_42]] +// CHECK: %[[MEMDESC_SUBVIEW_43:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_41]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_30]], %[[MEMDESC_SUBVIEW_43]] +// CHECK: scf.yield %[[DOT_38]], %[[ADDPTR_20]], %[[ADDPTR_31]], %[[SELECT_41]], %[[MEMDESC_SUBVIEW_42]], %[[MEMDESC_SUBVIEW_43]], %[[LOAD_35]] +// CHECK: } + + tt.func @indirect_bmm_vector(%arg0: tensor<16x16xi64, #blocked> {tt.constancy = 16 : i32, tt.divisibility = 16 : i32}, %arg1: index, %arg2: tensor<16x16x!tt.ptr, #blocked1> {tt.contiguity = 2 : i32, tt.divisibility = 16 : i32}, %arg3: tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, %arg4: tensor<16x16xi32, #blocked1> {tt.constancy = 16 : i32, tt.divisibility = 16 : i32}, %arg5: tensor<16x16x!tt.ptr, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}) -> tensor<16x16xf32, #mma> { + %c2 = arith.constant 2 : index + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma> + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c1_i32 = arith.constant 1 : i32 + %cst_0 = arith.constant dense<1> : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + %1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + %2 = arith.cmpi sgt, %arg1, %c0 : index + %3 = tt.splat %2 : i1 -> tensor<16xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %4 = tt.load %arg3, %3 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %5 = arith.cmpi sgt, %arg1, %c1 : index + %6 = tt.addptr %arg3, %cst_0 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %7 = tt.splat %2 : i1 -> tensor<16x16xi1, #blocked1> + %8 = tt.load %arg2, %7 : tensor<16x16x!tt.ptr, #blocked1> + %9 = tt.expand_dims %4 {axis = 1 : i32} : tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi64, #blocked> + %10 = tt.broadcast %9 : tensor<16x1xi64, #blocked> -> tensor<16x16xi64, #blocked> + %11 = arith.muli %arg0, %10 : tensor<16x16xi64, #blocked> + %12 = tt.addptr %arg5, %11 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi64, #blocked> + %13 = tt.splat %2 : i1 -> tensor<16x16xi1, #blocked> + %14 = tt.load %12, %13 : tensor<16x16x!tt.ptr, #blocked> + %15 = tt.splat %5 : i1 -> tensor<16xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %16 = tt.load %6, %15 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %17 = triton_gpu.memdesc_subview %0[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %8, %17 : tensor<16x16xf16, #blocked1> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + %18 = triton_gpu.memdesc_subview %1[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %14, %18 : tensor<16x16xf16, #blocked> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + %19:7 = scf.for %arg6 = %c0 to %arg1 step %c1 iter_args(%arg7 = %cst, %arg8 = %arg2, %arg9 = %6, %arg10 = %c0_i32, %arg11 = %17, %arg12 = %18, %arg13 = %16) -> (tensor<16x16xf32, 
#mma>, tensor<16x16x!tt.ptr, #blocked1>, tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, i32, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) { + %20 = arith.subi %arg1, %c2 : index + %21 = arith.cmpi slt, %arg6, %20 : index + %22 = arith.subi %arg1, %c1 : index + %23 = arith.cmpi slt, %arg6, %22 : index + %24 = triton_gpu.local_load %arg11 : !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %25 = triton_gpu.local_load %arg12 : !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %26 = tt.dot %24, %25, %arg7 : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> + %27 = tt.addptr %arg8, %arg4 : tensor<16x16x!tt.ptr, #blocked1>, tensor<16x16xi32, #blocked1> + %28 = tt.addptr %arg9, %cst_0 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %29 = tt.splat %23 : i1 -> tensor<16x16xi1, #blocked1> + %30 = tt.load %27, %29 : tensor<16x16x!tt.ptr, #blocked1> + %31 = tt.expand_dims %arg13 {axis = 1 : i32} : tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi64, #blocked> + %32 = tt.broadcast %31 : tensor<16x1xi64, #blocked> -> tensor<16x16xi64, #blocked> + %33 = arith.muli %arg0, %32 : tensor<16x16xi64, #blocked> + %34 = tt.addptr %arg5, %33 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi64, #blocked> + %35 = tt.splat %23 : i1 -> tensor<16x16xi1, #blocked> + %36 = tt.load %34, %35 : tensor<16x16x!tt.ptr, #blocked> + %37 = tt.splat %21 : i1 -> tensor<16xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %38 = tt.load %28, %37 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %39 = arith.addi %arg10, %c1_i32 : i32 + %40 = arith.cmpi slt, %39, %c1_i32 : i32 + %41 = arith.select %40, %39, %c0_i32 : i32 + %42 = triton_gpu.memdesc_subview %0[%41, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %30, %42 : tensor<16x16xf16, #blocked1> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + %43 = triton_gpu.memdesc_subview %1[%41, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %36, %43 : tensor<16x16xf16, #blocked> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + scf.yield %26, %27, %28, %41, %42, %43, %38 : tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr, #blocked1>, tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, i32, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + } + triton_gpu.local_dealloc %0 : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %1 : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + 
tt.return %19#0 : tensor<16x16xf32, #mma> + } +} + // ----- // CHECK-LABEL: sink_convert_dealloc diff --git a/test/TritonGPU/amd/amd-sched-2nd-load.mlir b/test/TritonGPU/amd/amd-sched-2nd-load.mlir index 312025a8de74..6afa25259067 100644 --- a/test/TritonGPU/amd/amd-sched-2nd-load.mlir +++ b/test/TritonGPU/amd/amd-sched-2nd-load.mlir @@ -35,11 +35,11 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war %c1 = arith.constant 1 : i32 %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { - %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr, #blocked> %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x128xf16, #dotOp0> - %5 = tt.load %B_ptr : tensor<128x256x!tt.ptr, #blocked1> %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x256xf16, #dotOp1> %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr, #blocked> + %5 = tt.load %B_ptr : tensor<128x256x!tt.ptr, #blocked1> triton_gpu.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> triton_gpu.local_store %5, %B_LDS : tensor<128x256xf16, #blocked1> -> !tt.memdesc<128x256xf16, #shared1, #triton_gpu.shared_memory, mutable> scf.yield %3 : tensor<256x256xf32, #mma> @@ -74,11 +74,11 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war %c1 = arith.constant 1 : i32 %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { - %4 = tt.load %A_ptr : tensor<256x64x!tt.ptr, #blocked> %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #dotOp0> - %5 = tt.load %B_ptr : tensor<64x256x!tt.ptr, #blocked1> %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x256xf16, #dotOp1> %3 = tt.dot %1, %2, %arg1 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + %4 = tt.load %A_ptr : tensor<256x64x!tt.ptr, #blocked> + %5 = tt.load %B_ptr : tensor<64x256x!tt.ptr, #blocked1> triton_gpu.local_store %4, %A_LDS : tensor<256x64xf16, #blocked> -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> triton_gpu.local_store %5, %B_LDS : tensor<64x256xf16, #blocked1> -> !tt.memdesc<64x256xf16, #shared1, #triton_gpu.shared_memory, mutable> scf.yield %3 : tensor<256x256xf32, #mma> @@ -101,9 +101,9 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war // Should NOT apply: tile size 256x64x128 with single dot // CHECK-LABEL: sink_2nd_load_256x64x128 // CHECK: %[[tileA:.*]] = tt.load -// CHECK-NEXT: local_load // CHECK-NEXT: %[[tileB:.*]] = tt.load // CHECK-NEXT: local_load +// CHECK-NEXT: local_load // CHECK-NEXT: tt.dot // CHECK-NEXT: triton_gpu.local_store %[[tileA]] // CHECK-NEXT: triton_gpu.local_store %[[tileB]] @@ -113,11 +113,11 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war %c1 = arith.constant 1 : i32 %cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #mma> %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 
iter_args(%arg1 = %cst) -> (tensor<256x64xf32, #mma>) : i32 { - %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr, #blocked> %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x128xf16, #dotOp0> - %5 = tt.load %B_ptr : tensor<128x64x!tt.ptr, #blocked1> %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x64xf16, #dotOp1> %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x64xf16, #dotOp1> -> tensor<256x64xf32, #mma> + %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr, #blocked> + %5 = tt.load %B_ptr : tensor<128x64x!tt.ptr, #blocked1> triton_gpu.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> triton_gpu.local_store %5, %B_LDS : tensor<128x64xf16, #blocked1> -> !tt.memdesc<128x64xf16, #shared1, #triton_gpu.shared_memory, mutable> scf.yield %3 : tensor<256x64xf32, #mma> @@ -140,9 +140,9 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war // Should NOT apply: tile size 256x256x32 with single dot // CHECK-LABEL: sink_2nd_load_256x256x32 // CHECK: %[[tileA:.*]] = tt.load -// CHECK-NEXT: local_load // CHECK-NEXT: %[[tileB:.*]] = tt.load // CHECK-NEXT: local_load +// CHECK-NEXT: local_load // CHECK-NEXT: tt.dot // CHECK-NEXT: triton_gpu.local_store %[[tileA]] // CHECK-NEXT: triton_gpu.local_store %[[tileB]] @@ -152,11 +152,11 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war %c1 = arith.constant 1 : i32 %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 { - %4 = tt.load %A_ptr : tensor<256x32x!tt.ptr, #blocked> %1 = triton_gpu.local_load %A_LDS : !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x32xf16, #dotOp0> - %5 = tt.load %B_ptr : tensor<32x256x!tt.ptr, #blocked1> %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<32x256xf16, #dotOp1> %3 = tt.dot %1, %2, %arg1 : tensor<256x32xf16, #dotOp0> * tensor<32x256xf16, #dotOp1> -> tensor<256x256xf32, #mma> + %4 = tt.load %A_ptr : tensor<256x32x!tt.ptr, #blocked> + %5 = tt.load %B_ptr : tensor<32x256x!tt.ptr, #blocked1> triton_gpu.local_store %4, %A_LDS : tensor<256x32xf16, #blocked> -> !tt.memdesc<256x32xf16, #shared, #triton_gpu.shared_memory, mutable> triton_gpu.local_store %5, %B_LDS : tensor<32x256xf16, #blocked1> -> !tt.memdesc<32x256xf16, #shared1, #triton_gpu.shared_memory, mutable> scf.yield %3 : tensor<256x256xf32, #mma> @@ -181,9 +181,9 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war // Should NOT apply: the 2nd load has a user before the dot // CHECK-LABEL: sink_2nd_load_128x128x128_user_before_dot // CHECK: %[[tileA:.*]] = tt.load -// CHECK-NEXT: local_load // CHECK-NEXT: %[[tileB:.*]] = tt.load // CHECK-NEXT: local_load +// CHECK-NEXT: local_load // CHECK-NEXT: tt.store // CHECK-NEXT: tt.dot // CHECK-NEXT: triton_gpu.local_store %[[tileA]] @@ -193,10 +193,10 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war %c1 = arith.constant 1 : i32 %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<128x128xf32, #mma>) : i32 { - %4 = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked> %1 
= triton_gpu.local_load %A_LDS : !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp0> - %5 = tt.load %B_ptr : tensor<128x128x!tt.ptr, #blocked> %2 = triton_gpu.local_load %B_LDS : !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<128x128xf16, #dotOp1> + %4 = tt.load %A_ptr : tensor<128x128x!tt.ptr, #blocked> + %5 = tt.load %B_ptr : tensor<128x128x!tt.ptr, #blocked> tt.store %B_ptr, %5 : tensor<128x128x!tt.ptr, #blocked> %3 = tt.dot %1, %2, %arg1 : tensor<128x128xf16, #dotOp0> * tensor<128x128xf16, #dotOp1> -> tensor<128x128xf32, #mma> triton_gpu.local_store %4, %A_LDS : tensor<128x128xf16, #blocked> -> !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory, mutable> @@ -213,12 +213,12 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war // Category 3: two dots in the for loop. Make sure the optimization is not applied // should NOT apply: two dots // CHECK-LABEL: sink_2nd_load_256x256x64_two_dot -// CHECK: tt.load -// CHECK-NEXT: tt.load -// CHECK-NEXT: triton_gpu.local_load +// CHECK: triton_gpu.local_load // CHECK-NEXT: triton_gpu.local_load // CHECK-NEXT: tt.dot // CHECK-NEXT: tt.dot +// CHECK-NEXT: tt.load +// CHECK-NEXT: tt.load // CHECK-NEXT: triton_gpu.local_store // CHECK-NEXT: triton_gpu.local_store #blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> diff --git a/test/TritonGPU/loop-pipeline-hip.mlir b/test/TritonGPU/loop-pipeline-hip.mlir index d6653b2b004c..3abcc581b906 100644 --- a/test/TritonGPU/loop-pipeline-hip.mlir +++ b/test/TritonGPU/loop-pipeline-hip.mlir @@ -35,9 +35,9 @@ module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> // CHECK: triton_gpu.local_store // CHECK: scf.for - // CHECK: tt.load // CHECK: tt.dot // CHECK: tt.dot + // CHECK: tt.load // CHECK: triton_gpu.local_store // CHECK: scf.yield %17:2 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 = %cst_1, %arg4 = %cst_2) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) : i32 { @@ -165,9 +165,9 @@ module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 // CHECK-LABEL: tt.func public @add_barrier_kernel // CHECK: tt.load // CHECK: scf.for -// CHECK: tt.load // CHECK: gpu.barrier // CHECK: tt.store +// CHECK: tt.load // CHECK: scf.yield // CHECK: gpu.barrier // CHECK: tt.store diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index 82f726899ea4..c0c4c1860e34 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -1,6 +1,5 @@ // RUN: triton-opt %s -split-input-file -tritongpu-loop-scheduling=num-stages=3 -tritongpu-pipeline=num-stages=3 -canonicalize | FileCheck %s --check-prefixes=COMMON,CHECK // RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline-v2=num_stages=2 -canonicalize | FileCheck %s --check-prefixes=COMMON,AMD -// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline-v2="num_stages=2 prefetch=1" -canonicalize | FileCheck %s --check-prefixes=COMMON,AMD_PREFETCH // 4 warps // matmul: 128x32 @ 32x128 -> 128x128 @@ -63,22 +62,22 @@ // AMD-DAG: %[[C0:.*]] = arith.constant 0 : index // AMD: %[[UB1:.*]] = arith.subi %[[UB:.*]], %arg2 : index // AMD: %[[FOR:.*]]:6 = scf.for %[[ARG5:.*]] = %[[LB:.*]] to %[[UB1]] step %[[STEP:.*]] iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = 
%{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}) -// AMD: %[[ADDPTR_34:.*]] = tt.addptr %[[ARG6]], %{{.*}} -// AMD: %[[ADDPTR_35:.*]] = tt.addptr %[[ARG7]], %{{.*}} -// AMD: %[[LOAD_36:.*]] = tt.load %[[ADDPTR_34]] -// AMD: %[[LOCAL_LOAD_37:.*]] = triton_gpu.local_load %[[ARG10]] -// AMD: %[[LOAD_38:.*]] = tt.load %[[ADDPTR_35]] -// AMD: %[[LOCAL_LOAD_39:.*]] = triton_gpu.local_load %[[ARG11]] -// AMD: %[[MULF_40:.*]] = arith.mulf %[[LOCAL_LOAD_39]], %{{.*}} -// AMD: %[[DOT_41:.*]] = tt.dot %[[LOCAL_LOAD_37]], %[[MULF_40]], %[[ARG8]] -// AMD: %[[ADDI_42:.*]] = arith.addi %[[ARG9]], %{{.*}} -// AMD: %[[CMPI_43:.*]] = arith.cmpi slt, %[[ADDI_42]], %{{.*}} -// AMD: %[[SELECT_44:.*]] = arith.select %[[CMPI_43]], %[[ADDI_42]], %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_45:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_44]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_36]], %[[MEMDESC_SUBVIEW_45]] -// AMD: %[[MEMDESC_SUBVIEW_46:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_44]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_38]], %[[MEMDESC_SUBVIEW_46]] -// AMD: scf.yield %[[ADDPTR_34]], %[[ADDPTR_35]], %[[DOT_41]], %[[SELECT_44]], %[[MEMDESC_SUBVIEW_45]], %[[MEMDESC_SUBVIEW_46]] +// AMD: %[[LOCAL_LOAD_32:.*]] = triton_gpu.local_load %[[ARG10]] +// AMD: %[[LOCAL_LOAD_33:.*]] = triton_gpu.local_load %[[ARG11]] +// AMD: %[[MULF_34:.*]] = arith.mulf %[[LOCAL_LOAD_33]], %{{.*}} +// AMD: %[[DOT_35:.*]] = tt.dot %[[LOCAL_LOAD_32]], %[[MULF_34]], %[[ARG8]] +// AMD: %[[ADDPTR_36:.*]] = tt.addptr %[[ARG6]], %{{.*}} +// AMD: %[[ADDPTR_37:.*]] = tt.addptr %[[ARG7]], %{{.*}} +// AMD: %[[LOAD_38:.*]] = tt.load %[[ADDPTR_36]] +// AMD: %[[LOAD_39:.*]] = tt.load %[[ADDPTR_37]] +// AMD: %[[ADDI_40:.*]] = arith.addi %[[ARG9]], %{{.*}} +// AMD: %[[CMPI_41:.*]] = arith.cmpi slt, %[[ADDI_40]], %{{.*}} +// AMD: %[[SELECT_42:.*]] = arith.select %[[CMPI_41]], %[[ADDI_40]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_43:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_42]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_38]], %[[MEMDESC_SUBVIEW_43]] +// AMD: %[[MEMDESC_SUBVIEW_44:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_42]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_39]], %[[MEMDESC_SUBVIEW_44]] +// AMD: scf.yield %[[ADDPTR_36]], %[[ADDPTR_37]], %[[DOT_35]], %[[SELECT_42]], %[[MEMDESC_SUBVIEW_43]], %[[MEMDESC_SUBVIEW_44]] // AMD: } // AMD: %[[CMPI_21:.*]] = arith.cmpi slt, %[[STEP]], %[[C0]] // AMD: %[[SELECT_22:.*]] = arith.select %[[CMPI_21]], %[[C1]], %[[CM1]] @@ -100,34 +99,6 @@ // AMD: triton_gpu.local_dealloc %{{.*}} // AMD: triton_gpu.local_dealloc %{{.*}} -// Prefetch pipelining adds another stage in between global load and compute. -// This stage will local_store, then local_load, creating a prefetch from shared -// memory into a register buffer for compute. 
-// -// AMD_PREFETCH-LABEL: tt.func @matmul_loop -// AMD_PREFETCH: triton_gpu.local_alloc -// AMD_PREFETCH: triton_gpu.local_alloc -// AMD_PREFETCH: tt.load -// AMD_PREFETCH: tt.load -// AMD_PREFETCH: triton_gpu.local_store -// AMD_PREFETCH: triton_gpu.local_store -// AMD_PREFETCH: tt.load -// AMD_PREFETCH: tt.load -// AMD_PREFETCH: triton_gpu.local_load -// AMD_PREFETCH: triton_gpu.local_load -// AMD_PREFETCH: scf.for -// AMD_PREFETCH: triton_gpu.local_store -// AMD_PREFETCH: triton_gpu.local_store -// AMD_PREFETCH: tt.load -// AMD_PREFETCH: tt.load -// AMD_PREFETCH: tt.dot -// AMD_PREFETCH: triton_gpu.local_load -// AMD_PREFETCH: triton_gpu.local_load -// AMD_PREFETCH: scf.yield -// AMD_PREFETCH: tt.dot -// AMD_PREFETCH: tt.dot -// AMD_PREFETCH: tt.return - module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, @@ -229,12 +200,10 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, // AMD: %[[SUBVIEW1:.*]] = triton_gpu.memdesc_subview // AMD: triton_gpu.local_store %{{.+}}, %[[SUBVIEW1]] // AMD: %[[FOR:.*]]:6 = scf.for -// AMD-COUNT-2: tt.addptr -// AMD: tt.load -// AMD: triton_gpu.local_load -// AMD: tt.load -// AMD: triton_gpu.local_load +// AMD-COUNT-2: triton_gpu.local_load // AMD: tt.dot +// AMD-COUNT-2: tt.addptr +// AMD-COUNT-2: tt.load // AMD: %[[SUBVIEW0:.*]] = triton_gpu.memdesc_subview // AMD: triton_gpu.local_store %{{.+}}, %[[SUBVIEW0]] // AMD: %[[SUBVIEW1:.*]] = triton_gpu.memdesc_subview @@ -248,8 +217,6 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, // AMD-COUNT-2: triton_gpu.local_dealloc // AMD: scf.yield %[[SEL1]] -// AMD_PREFETCH-LABEL: tt.func @matmul_loop_nested - tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C>{ @@ -332,10 +299,10 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, // AMD: triton_gpu.local_store %[[LOAD_15]], %[[MEMDESC_SUBVIEW_16]] // AMD: %[[SUBI_17:.*]] = arith.subi %{{.*}}, %{{.*}} // AMD: %{{.*}}:4 = scf.for %[[ARG5:.*]] = %{{.*}} to %[[SUBI_17]] step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %[[MEMDESC_SUBVIEW_16]]) -// AMD: %[[ADDPTR_32:.*]] = tt.addptr %[[ARG6]], %{{.*}} -// AMD: %[[LOAD_33:.*]] = tt.load %[[ADDPTR_32]] // AMD: %[[LOCAL_LOAD_30:.*]] = triton_gpu.local_load %[[ARG9]] // AMD: %[[DOT_31:.*]] = tt.dot %[[CONVERT_LAYOUT_11]], %[[LOCAL_LOAD_30]], %[[ARG7]] +// AMD: %[[ADDPTR_32:.*]] = tt.addptr %[[ARG6]], %{{.*}} +// AMD: %[[LOAD_33:.*]] = tt.load %[[ADDPTR_32]] // AMD: %[[ADDI_34:.*]] = arith.addi %[[ARG8]], %{{.*}} // AMD: %[[CMPI_35:.*]] = arith.cmpi slt, %[[ADDI_34]], %{{.*}} // AMD: %[[SELECT_36:.*]] = arith.select %[[CMPI_35]], %[[ADDI_34]], %{{.*}} @@ -344,22 +311,6 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, // AMD: scf.yield %[[ADDPTR_32]], %[[DOT_31]], %[[SELECT_36]], %[[MEMDESC_SUBVIEW_37]] // AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_12]] -// AMD_PREFETCH-LABEL: tt.func @matmul_loop_single_pipeline -// AMD_PREFETCH: triton_gpu.local_alloc -// AMD_PREFETCH: tt.load -// AMD_PREFETCH: triton_gpu.local_store -// AMD_PREFETCH: tt.load -// AMD_PREFETCH: triton_gpu.local_load -// AMD_PREFETCH: scf.for -// AMD_PREFETCH: triton_gpu.local_store -// AMD_PREFETCH: tt.load -// AMD_PREFETCH: tt.dot -// AMD_PREFETCH: 
triton_gpu.local_load -// AMD_PREFETCH: scf.yield -// AMD_PREFETCH: tt.dot -// AMD_PREFETCH: tt.dot -// AMD_PREFETCH: tt.return - tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> { @@ -414,110 +365,83 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, // CHECK: triton_gpu.async_wait {{.*}} {num = 2 : i32} // AMD-LABEL: tt.func @indirect_bmm_scalar -// AMD: %[[LOCAL_ALLOC_0:.*]] = triton_gpu.local_alloc -// AMD: %[[LOCAL_ALLOC_1:.*]] = triton_gpu.local_alloc -// AMD: %[[CMPI_2:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} -// AMD: %[[SPLAT_3:.*]] = tt.splat %[[CMPI_2]] -// AMD: %[[LOAD_4:.*]] = tt.load %{{.*}}, %[[SPLAT_3]] -// AMD: %[[LOAD_5:.*]] = tt.load %{{.*}}, %[[CMPI_2]] -// AMD: %[[MULI_6:.*]] = arith.muli %{{.*}}, %[[LOAD_5]] -// AMD: %[[SPLAT_7:.*]] = tt.splat %[[MULI_6]] -// AMD: %[[ADDPTR_8:.*]] = tt.addptr %{{.*}}, %[[SPLAT_7]] -// AMD: %[[SPLAT_9:.*]] = tt.splat %[[CMPI_2]] -// AMD: %[[LOAD_10:.*]] = tt.load %[[ADDPTR_8]], %[[SPLAT_9]] -// AMD: %[[CMPI_11:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_12:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%{{.*}}, %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_4]], %[[MEMDESC_SUBVIEW_12]] -// AMD: %[[MEMDESC_SUBVIEW_13:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%{{.*}}, %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_10]], %[[MEMDESC_SUBVIEW_13]] -// AMD: %[[ADDPTR_14:.*]] = tt.addptr %{{.*}}, %{{.*}} -// AMD: %[[ADDPTR_15:.*]] = tt.addptr %{{.*}}, %{{.*}} -// AMD: %[[SPLAT_16:.*]] = tt.splat %[[CMPI_11]] -// AMD: %[[LOAD_17:.*]] = tt.load %[[ADDPTR_14]], %[[SPLAT_16]] -// AMD: %[[LOAD_18:.*]] = tt.load %[[ADDPTR_15]], %[[CMPI_11]] -// AMD: %[[MULI_19:.*]] = arith.muli %{{.*}}, %[[LOAD_18]] -// AMD: %[[SPLAT_20:.*]] = tt.splat %[[MULI_19]] -// AMD: %[[ADDPTR_21:.*]] = tt.addptr %{{.*}}, %[[SPLAT_20]] -// AMD: %[[SPLAT_22:.*]] = tt.splat %[[CMPI_11]] -// AMD: %[[LOAD_23:.*]] = tt.load %[[ADDPTR_21]], %[[SPLAT_22]] -// AMD: %[[SUBI_24:.*]] = arith.subi %{{.*}}, %{{.*}} -// AMD: %{{.*}}:8 = scf.for %[[ARG6:.*]] = %{{.*}} to %[[SUBI_24]] step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %[[ADDPTR_14]], %[[ARG9:.*]] = %[[ADDPTR_15]], %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %[[LOAD_17]], %[[ARG12:.*]] = %[[LOAD_23]], %[[ARG13:.*]] = %[[MEMDESC_SUBVIEW_12]], %[[ARG14:.*]] = %[[MEMDESC_SUBVIEW_13]]) -// AMD: %[[ADDI_41:.*]] = arith.addi %[[ARG10]], %{{.*}} -// AMD: %[[CMPI_42:.*]] = arith.cmpi slt, %[[ADDI_41]], %{{.*}} -// AMD: %[[SELECT_43:.*]] = arith.select %[[CMPI_42]], %[[ADDI_41]], %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_44:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_43]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[ARG11]], %[[MEMDESC_SUBVIEW_44]] -// AMD: %[[MEMDESC_SUBVIEW_45:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_43]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[ARG12]], %[[MEMDESC_SUBVIEW_45]] +// AMD: %[[LOCAL_ALLOC_0:.*]] = triton_gpu.local_alloc +// AMD: %[[LOCAL_ALLOC_1:.*]] = triton_gpu.local_alloc +// AMD: %[[CMPI_2:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} +// AMD: %[[SPLAT_3:.*]] = tt.splat %[[CMPI_2]] +// AMD: %[[LOAD_4:.*]] = tt.load %{{.*}}, %[[SPLAT_3]] +// AMD: %[[LOAD_5:.*]] = tt.load %{{.*}}, %[[CMPI_2]] +// AMD: %[[MULI_6:.*]] = arith.muli %{{.*}}, %[[LOAD_5]] +// AMD: %[[SPLAT_7:.*]] = tt.splat %[[MULI_6]] +// AMD: 
%[[ADDPTR_8:.*]] = tt.addptr %{{.*}}, %[[SPLAT_7]] +// AMD: %[[SPLAT_9:.*]] = tt.splat %[[CMPI_2]] +// AMD: %[[LOAD_10:.*]] = tt.load %[[ADDPTR_8]], %[[SPLAT_9]] +// AMD: %[[CMPI_11:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} +// AMD: %[[ADDPTR_12:.*]] = tt.addptr %{{.*}}, %{{.*}} +// AMD: %[[ADDPTR_13:.*]] = tt.addptr %{{.*}}, %{{.*}} +// AMD: %[[SPLAT_14:.*]] = tt.splat %[[CMPI_11]] +// AMD: %[[LOAD_15:.*]] = tt.load %[[ADDPTR_12]], %[[SPLAT_14]] +// AMD: %[[LOAD_16:.*]] = tt.load %[[ADDPTR_13]], %[[CMPI_11]] +// AMD: %[[MULI_17:.*]] = arith.muli %{{.*}}, %[[LOAD_16]] +// AMD: %[[SPLAT_18:.*]] = tt.splat %[[MULI_17]] +// AMD: %[[ADDPTR_19:.*]] = tt.addptr %{{.*}}, %[[SPLAT_18]] +// AMD: %[[SPLAT_20:.*]] = tt.splat %[[CMPI_11]] +// AMD: %[[LOAD_21:.*]] = tt.load %[[ADDPTR_19]], %[[SPLAT_20]] +// AMD: %[[MEMDESC_SUBVIEW_22:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_4]], %[[MEMDESC_SUBVIEW_22]] +// AMD: %[[MEMDESC_SUBVIEW_23:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_10]], %[[MEMDESC_SUBVIEW_23]] +// AMD: %[[SUBI_24:.*]] = arith.subi %{{.*}}, %{{.*}} +// AMD: %{{.*}}:8 = scf.for %[[ARG6:.*]] = %{{.*}} to %[[SUBI_24]] step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %[[ADDPTR_12]], %[[ARG9:.*]] = %[[ADDPTR_13]], %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %[[MEMDESC_SUBVIEW_22]], %[[ARG12:.*]] = %[[MEMDESC_SUBVIEW_23]], %[[ARG13:.*]] = %[[LOAD_15]], %[[ARG14:.*]] = %[[LOAD_21]]) +// AMD: %[[LOCAL_LOAD_43:.*]] = triton_gpu.local_load %[[ARG11]] +// AMD: %[[LOCAL_LOAD_44:.*]] = triton_gpu.local_load %[[ARG12]] +// AMD: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_43]], %[[LOCAL_LOAD_44]], %[[ARG7]] // AMD: %[[ADDPTR_46:.*]] = tt.addptr %[[ARG8]], %{{.*}} // AMD: %[[ADDPTR_47:.*]] = tt.addptr %[[ARG9]], %{{.*}} // AMD: %[[LOAD_48:.*]] = tt.load %[[ADDPTR_46]] -// AMD: %[[LOCAL_LOAD_49:.*]] = triton_gpu.local_load %[[ARG13]] -// AMD: %[[LOAD_50:.*]] = tt.load %[[ADDPTR_47]] -// AMD: %[[MULI_51:.*]] = arith.muli %{{.*}}, %[[LOAD_50]] -// AMD: %[[SPLAT_52:.*]] = tt.splat %[[MULI_51]] -// AMD: %[[ADDPTR_53:.*]] = tt.addptr %{{.*}}, %[[SPLAT_52]] -// AMD: %[[LOAD_54:.*]] = tt.load %[[ADDPTR_53]] -// AMD: %[[LOCAL_LOAD_55:.*]] = triton_gpu.local_load %[[ARG14]] -// AMD: %[[DOT_56:.*]] = tt.dot %[[LOCAL_LOAD_49]], %[[LOCAL_LOAD_55]], %[[ARG7]] -// AMD: scf.yield %[[DOT_56]], %[[ADDPTR_46]], %[[ADDPTR_47]], %[[SELECT_43]], %[[LOAD_48]], %[[LOAD_54]], %[[MEMDESC_SUBVIEW_44]], %[[MEMDESC_SUBVIEW_45]] -// AMD: } -// AMD: %[[CMPI_26:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}} -// AMD: %[[CMPI_27:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}} -// AMD: %[[ADDI_28:.*]] = arith.addi %{{.*}}#3, %{{.*}} -// AMD: %[[CMPI_29:.*]] = arith.cmpi slt, %[[ADDI_28]], %{{.*}} -// AMD: %[[SELECT_30:.*]] = arith.select %[[CMPI_29]], %[[ADDI_28]], %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_31:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_30]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %{{.*}}#4, %[[MEMDESC_SUBVIEW_31]] -// AMD: %[[MEMDESC_SUBVIEW_32:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_30]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %{{.*}}#5, %[[MEMDESC_SUBVIEW_32]] -// AMD: %[[LOCAL_LOAD_33:.*]] = triton_gpu.local_load %{{.*}}#6 -// AMD: %[[LOCAL_LOAD_34:.*]] = triton_gpu.local_load %{{.*}}#7 -// AMD: %[[IF_35:.*]] = scf.if %[[CMPI_26]] -// AMD: %[[DOT_41:.*]] = tt.dot %[[LOCAL_LOAD_33]], %[[LOCAL_LOAD_34]], %{{.*}}#0 -// AMD: 
scf.yield %[[DOT_41]] -// AMD: } else { -// AMD: scf.yield %{{.*}}#0 -// AMD: } -// AMD: %[[SELECT_36:.*]] = arith.select %[[CMPI_26]], %[[IF_35]], %{{.*}}#0 -// AMD: %[[LOCAL_LOAD_37:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_31]] -// AMD: %[[LOCAL_LOAD_38:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_32]] -// AMD: %[[IF_39:.*]] = scf.if %[[CMPI_27]] -// AMD: %[[DOT_41:.*]] = tt.dot %[[LOCAL_LOAD_37]], %[[LOCAL_LOAD_38]], %[[SELECT_36]] -// AMD: scf.yield %[[DOT_41]] -// AMD: } else { -// AMD: scf.yield %[[SELECT_36]] -// AMD: } -// AMD: %[[SELECT_40:.*]] = arith.select %[[CMPI_27]], %[[IF_39]], %[[SELECT_36]] -// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_0]] -// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_1]] - -// AMD_PREFETCH-LABEL: tt.func @indirect_bmm_scalar -// AMD_PREFETCH: triton_gpu.local_alloc -// AMD_PREFETCH: triton_gpu.local_alloc -// AMD_PREFETCH: tt.load -// AMD_PREFETCH: tt.load -// AMD_PREFETCH: tt.load -// AMD_PREFETCH: triton_gpu.local_store -// AMD_PREFETCH: triton_gpu.local_store -// AMD_PREFETCH: tt.load -// AMD_PREFETCH: tt.load -// AMD_PREFETCH: tt.load -// AMD_PREFETCH: triton_gpu.local_load -// AMD_PREFETCH: triton_gpu.local_load -// AMD_PREFETCH: scf.for -// AMD_PREFETCH: triton_gpu.local_store -// AMD_PREFETCH: triton_gpu.local_store -// AMD_PREFETCH: tt.load -// AMD_PREFETCH: tt.load -// AMD_PREFETCH: tt.load -// AMD_PREFETCH: tt.dot -// AMD_PREFETCH: triton_gpu.local_load -// AMD_PREFETCH: scf.yield -// AMD_PREFETCH: tt.dot -// AMD_PREFETCH: tt.dot -// AMD_PREFETCH: tt.dot -// AMD_PREFETCH: tt.return +// AMD: %[[LOAD_49:.*]] = tt.load %[[ADDPTR_47]] +// AMD: %[[MULI_50:.*]] = arith.muli %{{.*}}, %[[LOAD_49]] +// AMD: %[[SPLAT_51:.*]] = tt.splat %[[MULI_50]] +// AMD: %[[ADDPTR_52:.*]] = tt.addptr %{{.*}}, %[[SPLAT_51]] +// AMD: %[[LOAD_53:.*]] = tt.load %[[ADDPTR_52]] +// AMD: %[[ADDI_54:.*]] = arith.addi %[[ARG10]], %{{.*}} +// AMD: %[[CMPI_55:.*]] = arith.cmpi slt, %[[ADDI_54]], %{{.*}} +// AMD: %[[SELECT_56:.*]] = arith.select %[[CMPI_55]], %[[ADDI_54]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_57:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_56]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[ARG13]], %[[MEMDESC_SUBVIEW_57]] +// AMD: %[[MEMDESC_SUBVIEW_58:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_56]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[ARG14]], %[[MEMDESC_SUBVIEW_58]] +// AMD: scf.yield %[[DOT_45]], %[[ADDPTR_46]], %[[ADDPTR_47]], %[[SELECT_56]], %[[MEMDESC_SUBVIEW_57]], %[[MEMDESC_SUBVIEW_58]], %[[LOAD_48]], %[[LOAD_53]] +// AMD: } +// AMD: %[[CMPI_26:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}} +// AMD: %[[CMPI_27:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}} +// AMD: %[[LOCAL_LOAD_28:.*]] = triton_gpu.local_load %{{.*}}#4 +// AMD: %[[LOCAL_LOAD_29:.*]] = triton_gpu.local_load %{{.*}}#5 +// AMD: %[[IF_30:.*]] = scf.if %[[CMPI_26]] +// AMD: %[[DOT_41:.*]] = tt.dot %[[LOCAL_LOAD_28]], %[[LOCAL_LOAD_29]], %{{.*}}#0 +// AMD: scf.yield %[[DOT_41]] +// AMD: } else { +// AMD: scf.yield %{{.*}}#0 +// AMD: } +// AMD: %[[ADDI_31:.*]] = arith.addi %{{.*}}#3, %{{.*}} +// AMD: %[[CMPI_32:.*]] = arith.cmpi slt, %[[ADDI_31]], %{{.*}} +// AMD: %[[SELECT_33:.*]] = arith.select %[[CMPI_32]], %[[ADDI_31]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_34:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_33]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %{{.*}}#6, %[[MEMDESC_SUBVIEW_34]] +// AMD: %[[MEMDESC_SUBVIEW_35:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_33]], %{{.*}}, %{{.*}}] 
+// AMD: triton_gpu.local_store %{{.*}}#7, %[[MEMDESC_SUBVIEW_35]] +// AMD: %[[SELECT_36:.*]] = arith.select %[[CMPI_26]], %[[IF_30]], %{{.*}}#0 +// AMD: %[[LOCAL_LOAD_37:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_34]] +// AMD: %[[LOCAL_LOAD_38:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_35]] +// AMD: %[[IF_39:.*]] = scf.if %[[CMPI_27]] +// AMD: %[[DOT_41:.*]] = tt.dot %[[LOCAL_LOAD_37]], %[[LOCAL_LOAD_38]], %[[SELECT_36]] +// AMD: scf.yield %[[DOT_41]] +// AMD: } else { +// AMD: scf.yield %[[SELECT_36]] +// AMD: } +// AMD: %[[SELECT_40:.*]] = arith.select %[[CMPI_27]], %[[IF_39]], %[[SELECT_36]] +// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_0]] +// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_1]] tt.func @indirect_bmm_scalar(%77: i64 {tt.divisibility=16: i32}, %76: index, @@ -568,13 +492,11 @@ tt.func @indirect_bmm_scalar(%77: i64 {tt.divisibility=16: i32}, // AMD-LABEL: tt.func @indirect_bmm_scalar_dist_one // AMD-COUNT-4: tt.load // AMD: scf.for -// AMD: tt.load // AMD: tt.dot +// AMD: tt.load // AMD: triton_gpu.local_store // AMD: scf.yield -// AMD_PREFETCH-LABEL: tt.func @indirect_bmm_scalar_dist_one - tt.func @indirect_bmm_scalar_dist_one(%77: i64 {tt.divisibility=16: i32}, %76: index, %49: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, @@ -638,42 +560,40 @@ tt.func @indirect_bmm_scalar_dist_one(%77: i64 {tt.divisibility=16: i32}, // AMD: %[[ADDPTR_6:.*]] = tt.addptr %{{.*}}, %{{.*}} // AMD: %[[SPLAT_7:.*]] = tt.splat %[[CMPI_2]] // AMD: %[[LOAD_8:.*]] = tt.load %{{.*}}, %[[SPLAT_7]] -// AMD: %[[SPLAT_9:.*]] = tt.splat %[[CMPI_5]] -// AMD: %[[LOAD_10:.*]] = tt.load %[[ADDPTR_6]], %[[SPLAT_9]] -// AMD: %[[EXPAND_DIMS_11:.*]] = tt.expand_dims %[[LOAD_4]] {axis = 1 : i32} -// AMD: %[[BROADCAST_12:.*]] = tt.broadcast %[[EXPAND_DIMS_11]] -// AMD: %[[MULI_13:.*]] = arith.muli %{{.*}}, %[[BROADCAST_12]] -// AMD: %[[ADDPTR_14:.*]] = tt.addptr %{{.*}}, %[[MULI_13]] -// AMD: %[[SPLAT_15:.*]] = tt.splat %[[CMPI_2]] -// AMD: %[[LOAD_16:.*]] = tt.load %[[ADDPTR_14]], %[[SPLAT_15]] +// AMD: %[[EXPAND_DIMS_9:.*]] = tt.expand_dims %[[LOAD_4]] {axis = 1 : i32} +// AMD: %[[BROADCAST_10:.*]] = tt.broadcast %[[EXPAND_DIMS_9]] +// AMD: %[[MULI_11:.*]] = arith.muli %{{.*}}, %[[BROADCAST_10]] +// AMD: %[[ADDPTR_12:.*]] = tt.addptr %{{.*}}, %[[MULI_11]] +// AMD: %[[SPLAT_13:.*]] = tt.splat %[[CMPI_2]] +// AMD: %[[LOAD_14:.*]] = tt.load %[[ADDPTR_12]], %[[SPLAT_13]] +// AMD: %[[SPLAT_15:.*]] = tt.splat %[[CMPI_5]] +// AMD: %[[LOAD_16:.*]] = tt.load %[[ADDPTR_6]], %[[SPLAT_15]] // AMD: %[[MEMDESC_SUBVIEW_17:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%{{.*}}, %{{.*}}, %{{.*}}] // AMD: triton_gpu.local_store %[[LOAD_8]], %[[MEMDESC_SUBVIEW_17]] // AMD: %[[MEMDESC_SUBVIEW_18:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%{{.*}}, %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_16]], %[[MEMDESC_SUBVIEW_18]] +// AMD: triton_gpu.local_store %[[LOAD_14]], %[[MEMDESC_SUBVIEW_18]] // AMD: %[[SUBI_19:.*]] = arith.subi %{{.*}}, %{{.*}} -// AMD: %{{.*}}:7 = scf.for %[[ARG6:.*]] = %{{.*}} to %[[SUBI_19]] step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %[[ADDPTR_6]], %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %[[MEMDESC_SUBVIEW_17]], %[[ARG12:.*]] = %[[LOAD_10]], %[[ARG13:.*]] = %[[MEMDESC_SUBVIEW_18]]) -// AMD: %[[ADDPTR_47:.*]] = tt.addptr %[[ARG8]], %{{.*}} -// AMD: %[[ADDPTR_48:.*]] = tt.addptr %[[ARG9]], %{{.*}} -// AMD: %[[LOAD_49:.*]] = tt.load %[[ADDPTR_47]] -// AMD: %[[LOCAL_LOAD_50:.*]] = triton_gpu.local_load 
%[[ARG11]] -// AMD: %[[LOAD_51:.*]] = tt.load %[[ADDPTR_48]] -// AMD: %[[EXPAND_DIMS_52:.*]] = tt.expand_dims %[[ARG12]] {axis = 1 : i32} -// AMD: %[[BROADCAST_53:.*]] = tt.broadcast %[[EXPAND_DIMS_52]] -// AMD: %[[MULI_54:.*]] = arith.muli %{{.*}}, %[[BROADCAST_53]] -// AMD: %[[ADDPTR_55:.*]] = tt.addptr %{{.*}}, %[[MULI_54]] -// AMD: %[[LOAD_56:.*]] = tt.load %[[ADDPTR_55]] -// AMD: %[[LOCAL_LOAD_57:.*]] = triton_gpu.local_load %[[ARG13]] -// AMD: %[[DOT_58:.*]] = tt.dot %[[LOCAL_LOAD_50]], %[[LOCAL_LOAD_57]], %[[ARG7]] -// AMD: %[[ADDI_59:.*]] = arith.addi %[[ARG10]], %{{.*}} -// AMD: %[[CMPI_60:.*]] = arith.cmpi slt, %[[ADDI_59]], %{{.*}} -// AMD: %[[SELECT_61:.*]] = arith.select %[[CMPI_60]], %[[ADDI_59]], %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_62:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_61]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_49]], %[[MEMDESC_SUBVIEW_62]] -// AMD: %[[MEMDESC_SUBVIEW_63:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_61]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_56]], %[[MEMDESC_SUBVIEW_63]] -// AMD: scf.yield %[[DOT_58]], %[[ADDPTR_47]], %[[ADDPTR_48]], %[[SELECT_61]], %[[MEMDESC_SUBVIEW_62]], %[[LOAD_51]], %[[MEMDESC_SUBVIEW_63]] - -// AMD_PREFETCH-LABEL: tt.func @indirect_bmm_vector +// AMD: %{{.*}}:7 = scf.for %[[ARG6:.*]] = %{{.*}} to %[[SUBI_19]] step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %[[ADDPTR_6]], %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %[[MEMDESC_SUBVIEW_17]], %[[ARG12:.*]] = %[[MEMDESC_SUBVIEW_18]], %[[ARG13:.*]] = %[[LOAD_16]]) +// AMD: %[[LOCAL_LOAD_47:.*]] = triton_gpu.local_load %[[ARG11]] +// AMD: %[[LOCAL_LOAD_48:.*]] = triton_gpu.local_load %[[ARG12]] +// AMD: %[[DOT_49:.*]] = tt.dot %[[LOCAL_LOAD_47]], %[[LOCAL_LOAD_48]], %[[ARG7]] +// AMD: %[[ADDPTR_50:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// AMD: %[[ADDPTR_51:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// AMD: %[[LOAD_52:.*]] = tt.load %[[ADDPTR_50]] +// AMD: %[[EXPAND_DIMS_53:.*]] = tt.expand_dims %[[ARG13]] {axis = 1 : i32} +// AMD: %[[BROADCAST_54:.*]] = tt.broadcast %[[EXPAND_DIMS_53]] +// AMD: %[[MULI_55:.*]] = arith.muli %{{.*}}, %[[BROADCAST_54]] +// AMD: %[[ADDPTR_56:.*]] = tt.addptr %{{.*}}, %[[MULI_55]] +// AMD: %[[LOAD_57:.*]] = tt.load %[[ADDPTR_56]] +// AMD: %[[LOAD_58:.*]] = tt.load %[[ADDPTR_51]] +// AMD: %[[ADDI_59:.*]] = arith.addi %[[ARG10]], %{{.*}} +// AMD: %[[CMPI_60:.*]] = arith.cmpi slt, %[[ADDI_59]], %{{.*}} +// AMD: %[[SELECT_61:.*]] = arith.select %[[CMPI_60]], %[[ADDI_59]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_62:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_61]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_52]], %[[MEMDESC_SUBVIEW_62]] +// AMD: %[[MEMDESC_SUBVIEW_63:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_61]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_57]], %[[MEMDESC_SUBVIEW_63]] +// AMD: scf.yield %[[DOT_49]], %[[ADDPTR_50]], %[[ADDPTR_51]], %[[SELECT_61]], %[[MEMDESC_SUBVIEW_62]], %[[MEMDESC_SUBVIEW_63]], %[[LOAD_58]] tt.func @indirect_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt.constancy=16: i32}, %76: index, @@ -824,6 +744,7 @@ tt.func @cross_iter_dep(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, // COMMON: tt.expand_dims // COMMON: tt.expand_dims // COMMON: tt.expand_dims %arg5 +// COMMON-NEXT: tt.expand_dims %arg5 // COMMON: %[[PTR0:.*]] = tt.splat %arg6 // COMMON: %[[PTR1:.*]] = tt.addptr %[[PTR0]] // COMMON-NEXT: tt.load %[[PTR1]] @@ -1055,65 +976,65 @@ 
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // AMD: %[[LOCAL_ALLOC_0:.*]] = triton_gpu.local_alloc // AMD: %[[LOCAL_ALLOC_1:.*]] = triton_gpu.local_alloc // AMD: %{{.*}}:7 = scf.for %[[ARG6:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}}) -// AMD: %[[ADDPTR_47:.*]] = tt.addptr %[[ARG8]], %{{.*}} -// AMD: %[[ADDPTR_48:.*]] = tt.addptr %[[ARG9]], %{{.*}} -// AMD: %[[LOAD_49:.*]] = tt.load %[[ADDPTR_47]] -// AMD: %[[LOCAL_LOAD_50:.*]] = triton_gpu.local_load %[[ARG11]] -// AMD: %[[LOAD_51:.*]] = tt.load %[[ADDPTR_48]] -// AMD: %[[EXPAND_DIMS_52:.*]] = tt.expand_dims %[[ARG12]] {axis = 1 : i32} -// AMD: %[[BROADCAST_53:.*]] = tt.broadcast %[[EXPAND_DIMS_52]] -// AMD: %[[MULI_54:.*]] = arith.muli %{{.*}}, %[[BROADCAST_53]] -// AMD: %[[ADDPTR_55:.*]] = tt.addptr %{{.*}}, %[[MULI_54]] -// AMD: %[[LOAD_56:.*]] = tt.load %[[ADDPTR_55]] -// AMD: %[[LOCAL_LOAD_57:.*]] = triton_gpu.local_load %[[ARG13]] -// AMD: %[[DOT_58:.*]] = tt.dot %[[LOCAL_LOAD_50]], %[[LOCAL_LOAD_57]], %[[ARG7]] +// AMD: %[[LOCAL_LOAD_47:.*]] = triton_gpu.local_load %[[ARG11]] +// AMD: %[[LOCAL_LOAD_48:.*]] = triton_gpu.local_load %[[ARG12]] +// AMD: %[[DOT_49:.*]] = tt.dot %[[LOCAL_LOAD_47]], %[[LOCAL_LOAD_48]], %[[ARG7]] +// AMD: %[[ADDPTR_50:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// AMD: %[[ADDPTR_51:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// AMD: %[[LOAD_52:.*]] = tt.load %[[ADDPTR_50]] +// AMD: %[[EXPAND_DIMS_53:.*]] = tt.expand_dims %[[ARG13]] {axis = 1 : i32} +// AMD: %[[BROADCAST_54:.*]] = tt.broadcast %[[EXPAND_DIMS_53]] +// AMD: %[[MULI_55:.*]] = arith.muli %{{.*}}, %[[BROADCAST_54]] +// AMD: %[[ADDPTR_56:.*]] = tt.addptr %{{.*}}, %[[MULI_55]] +// AMD: %[[LOAD_57:.*]] = tt.load %[[ADDPTR_56]] +// AMD: %[[LOAD_58:.*]] = tt.load %[[ADDPTR_51]] // AMD: %[[ADDI_59:.*]] = arith.addi %[[ARG10]], %{{.*}} // AMD: %[[CMPI_60:.*]] = arith.cmpi slt, %[[ADDI_59]], %{{.*}} // AMD: %[[SELECT_61:.*]] = arith.select %[[CMPI_60]], %[[ADDI_59]], %{{.*}} // AMD: %[[MEMDESC_SUBVIEW_62:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_61]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_49]], %[[MEMDESC_SUBVIEW_62]] +// AMD: triton_gpu.local_store %[[LOAD_52]], %[[MEMDESC_SUBVIEW_62]] // AMD: %[[MEMDESC_SUBVIEW_63:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_61]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_56]], %[[MEMDESC_SUBVIEW_63]] -// AMD: scf.yield %[[DOT_58]], %[[ADDPTR_47]], %[[ADDPTR_48]], %[[SELECT_61]], %[[MEMDESC_SUBVIEW_62]], %[[LOAD_51]], %[[MEMDESC_SUBVIEW_63]] +// AMD: triton_gpu.local_store %[[LOAD_57]], %[[MEMDESC_SUBVIEW_63]] +// AMD: scf.yield %[[DOT_49]], %[[ADDPTR_50]], %[[ADDPTR_51]], %[[SELECT_61]], %[[MEMDESC_SUBVIEW_62]], %[[MEMDESC_SUBVIEW_63]], %[[LOAD_58]] +// AMD: } +// AMD: %[[CMPI_21:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}} +// AMD: %[[CMPI_22:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}} +// AMD: %[[LOCAL_LOAD_23:.*]] = triton_gpu.local_load %{{.*}}#4 +// AMD: %[[LOCAL_LOAD_24:.*]] = triton_gpu.local_load %{{.*}}#5 +// AMD: %[[IF_25:.*]] = scf.if %[[CMPI_21]] +// AMD: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_23]], %[[LOCAL_LOAD_24]], %{{.*}}#0 +// AMD: scf.yield %[[DOT_45]] +// AMD: } else { +// AMD: scf.yield %{{.*}}#0 // AMD: } -// AMD: %[[CMPI_21:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}} -// AMD: %[[CMPI_22:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}} -// AMD: %[[ADDPTR_23:.*]] = 
tt.addptr %{{.*}}#1, %{{.*}} -// AMD: %[[SPLAT_24:.*]] = tt.splat %[[CMPI_22]] -// AMD: %[[LOAD_25:.*]] = tt.load %[[ADDPTR_23]], %[[SPLAT_24]] -// AMD: %[[LOCAL_LOAD_26:.*]] = triton_gpu.local_load %{{.*}}#4 -// AMD: %[[EXPAND_DIMS_27:.*]] = tt.expand_dims %{{.*}}#5 {axis = 1 : i32} -// AMD: %[[BROADCAST_28:.*]] = tt.broadcast %[[EXPAND_DIMS_27]] -// AMD: %[[MULI_29:.*]] = arith.muli %{{.*}}, %[[BROADCAST_28]] -// AMD: %[[ADDPTR_30:.*]] = tt.addptr %{{.*}}, %[[MULI_29]] -// AMD: %[[SPLAT_31:.*]] = tt.splat %[[CMPI_22]] -// AMD: %[[LOAD_32:.*]] = tt.load %[[ADDPTR_30]], %[[SPLAT_31]] -// AMD: %[[LOCAL_LOAD_33:.*]] = triton_gpu.local_load %{{.*}}#6 -// AMD: %[[IF_34:.*]] = scf.if %[[CMPI_21]] -// AMD: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_26]], %[[LOCAL_LOAD_33]], %{{.*}}#0 -// AMD: scf.yield %[[DOT_45]] -// AMD: } else { -// AMD: scf.yield %{{.*}}#0 -// AMD: } -// AMD: %[[ADDI_35:.*]] = arith.addi %{{.*}}#3, %{{.*}} -// AMD: %[[CMPI_36:.*]] = arith.cmpi slt, %[[ADDI_35]], %{{.*}} -// AMD: %[[SELECT_37:.*]] = arith.select %[[CMPI_36]], %[[ADDI_35]], %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_38:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_37]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_25]], %[[MEMDESC_SUBVIEW_38]] -// AMD: %[[MEMDESC_SUBVIEW_39:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_37]], %{{.*}}, %{{.*}}] -// AMD: triton_gpu.local_store %[[LOAD_32]], %[[MEMDESC_SUBVIEW_39]] -// AMD: %[[SELECT_40:.*]] = arith.select %[[CMPI_21]], %[[IF_34]], %{{.*}}#0 -// AMD: %[[LOCAL_LOAD_41:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_38]] -// AMD: %[[LOCAL_LOAD_42:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_39]] -// AMD: %[[IF_43:.*]] = scf.if %[[CMPI_22]] -// AMD: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_41]], %[[LOCAL_LOAD_42]], %[[SELECT_40]] -// AMD: scf.yield %[[DOT_45]] -// AMD: } else { -// AMD: scf.yield %[[SELECT_40]] -// AMD: } -// AMD: %[[SELECT_44:.*]] = arith.select %[[CMPI_22]], %[[IF_43]], %[[SELECT_40]] -// AMD: triton_gpu.local_dealloc %{{.*}} -// AMD: triton_gpu.local_dealloc %{{.*}} +// AMD: %[[ADDPTR_26:.*]] = tt.addptr %{{.*}}#1, %{{.*}} +// AMD: %[[SPLAT_27:.*]] = tt.splat %[[CMPI_22]] +// AMD: %[[LOAD_28:.*]] = tt.load %[[ADDPTR_26]], %[[SPLAT_27]] +// AMD: %[[EXPAND_DIMS_29:.*]] = tt.expand_dims %{{.*}}#6 {axis = 1 : i32} +// AMD: %[[BROADCAST_30:.*]] = tt.broadcast %[[EXPAND_DIMS_29]] +// AMD: %[[MULI_31:.*]] = arith.muli %{{.*}}, %[[BROADCAST_30]] +// AMD: %[[ADDPTR_32:.*]] = tt.addptr %{{.*}}, %[[MULI_31]] +// AMD: %[[SPLAT_33:.*]] = tt.splat %[[CMPI_22]] +// AMD: %[[LOAD_34:.*]] = tt.load %[[ADDPTR_32]], %[[SPLAT_33]] +// AMD: %[[ADDI_35:.*]] = arith.addi %{{.*}}#3, %{{.*}} +// AMD: %[[CMPI_36:.*]] = arith.cmpi slt, %[[ADDI_35]], %{{.*}} +// AMD: %[[SELECT_37:.*]] = arith.select %[[CMPI_36]], %[[ADDI_35]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_38:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_37]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_28]], %[[MEMDESC_SUBVIEW_38]] +// AMD: %[[MEMDESC_SUBVIEW_39:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_37]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_34]], %[[MEMDESC_SUBVIEW_39]] +// AMD: %[[SELECT_40:.*]] = arith.select %[[CMPI_21]], %[[IF_25]], %{{.*}}#0 +// AMD: %[[LOCAL_LOAD_41:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_38]] +// AMD: %[[LOCAL_LOAD_42:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_39]] +// AMD: %[[IF_43:.*]] = scf.if %[[CMPI_22]] +// AMD: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_41]], %[[LOCAL_LOAD_42]], 
%[[SELECT_40]] +// AMD: scf.yield %[[DOT_45]] +// AMD: } else { +// AMD: scf.yield %[[SELECT_40]] +// AMD: } +// AMD: %[[SELECT_44:.*]] = arith.select %[[CMPI_22]], %[[IF_43]], %[[SELECT_40]] +// AMD: triton_gpu.local_dealloc %{{.*}} +// AMD: triton_gpu.local_dealloc %{{.*}} #AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> @@ -1321,23 +1242,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // AMD: triton_gpu.local_store // AMD: scf.yield // AMD: triton_gpu.local_dealloc - -// AMD_PREFETCH-LABEL: tt.func public @nested_loops -// AMD_PREFETCH-NOT: triton_gpu.local_alloc -// AMD_PREFETCH: scf.for -// AMD_PREFETCH: triton_gpu.local_alloc -// AMD_PREFETCH: tt.load -// AMD_PREFETCH: triton_gpu.local_store -// AMD_PREFETCH: tt.load -// AMD_PREFETCH: triton_gpu.local_load -// AMD_PREFETCH: scf.for -// AMD_PREFETCH: triton_gpu.local_store -// AMD_PREFETCH: tt.load -// AMD_PREFETCH: tt.dot -// AMD_PREFETCH: triton_gpu.local_load -// AMD_PREFETCH: scf.yield -// AMD_PREFETCH: triton_gpu.local_dealloc - #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [2, 1], order = [1, 0]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 2], instrShape = [16, 8]}> #shared = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> @@ -1667,25 +1571,10 @@ tt.func @matmul_nested_ops(%lb : index, %ub : index, %step : index, // AMD: tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] // AMD: tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] // AMD: scf.for +// AMD: arith.select +// AMD: arith.addf // AMD: %[[A:.*]] = tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] // AMD: %[[B:.*]] = tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] -// AMD: arith.addf -// AMD: arith.select -// AMD: scf.yield - -// AMD_PREFETCH-LABEL: @masked_add_kernel -// AMD_PREFETCH: %[[CONSTANT:.*]] = arith.constant dense<0xFF800000> -// AMD_PREFETCH-COUNT-6: tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] -// AMD_PREFETCH: scf.for -// AMD_PREFETCH: %[[A:.*]] = tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] -// AMD_PREFETCH: %[[B:.*]] = tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] -// AMD_PREFETCH: arith.addf -// AMD_PREFETCH: arith.select -// AMD_PREFETCH: tt.store -// AMD_PREFETCH: scf.yield -// AMD_PREFETCH: tt.store -// AMD_PREFETCH: tt.store -// AMD_PREFETCH: tt.store #blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index dfd2aac21e9b..67dc3f2b640b 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -221,8 +221,7 @@ def make_ttgir(mod, metadata, options): "num_stages == 0. 
Now it will not happen anymore; " "please update to use num_stages == 2 for " "equivalent behavior in the past.") - prefetch = os.getenv("TRITON_HIP_STREAM_PREFETCH", "0") == "1" - amd.passes.ttgpuir.add_stream_pipelinev2(pm, options.num_stages, prefetch) + amd.passes.ttgpuir.add_stream_pipelinev2(pm, options.num_stages) passes.common.add_canonicalizer(pm) amd.passes.ttgpuir.insert_instruction_sched_hints(pm) passes.ttgpuir.add_optimize_dot_operands(pm, True) diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h index 636743d305f9..d0ffdae28ed5 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h @@ -7,8 +7,7 @@ namespace mlir { -std::unique_ptr createTritonAMDGPUStreamPipelineV2Pass(int numStages = 2, - int prefetch = 0); +std::unique_ptr createTritonAMDGPUStreamPipelineV2Pass(int numStages = 2); std::unique_ptr createTritonAMDGPUAccelerateMatmulPass(std::string archGenName = std::string(), diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td index 85604dcaca18..93345b0d6de4 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td @@ -18,10 +18,7 @@ def TritonAMDGPUStreamPipelineV2 : Pass<"tritonamdgpu-stream-pipeline-v2", "mlir let options = [ Option<"numStages", "num_stages", "int32_t", /*default*/"2", - "Number of Pipeline stages">, - Option<"prefetch", "prefetch", - "int32_t", /*default*/"0", - "Enable prefetch from shared memory"> + "Number of Pipeline stages"> ]; } diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp index 85bcb8a0c7d0..54f13083a249 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp @@ -366,10 +366,12 @@ struct TritonAMDGPUReorderInstructionsPass moveUpTranspose(funcOp); - if (isPureMatmulProblem(funcOp)) - sinkSecondLoad(funcOp); - else + moveUpTranspose(funcOp); + + if (isPureMatmulProblem(funcOp)) { scheduleGlobalLoadLocalStore(funcOp); + sinkSecondLoad(funcOp); + } } } }; diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp index 1a4dd8227c73..3b4935026c3f 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp @@ -19,7 +19,7 @@ // modulo schedule and an expander that rewrites the loop and emits a prologue // and epilogue. This pass first calls a helper that will pre-process the IR // to create stream operations and create a modulo schedule. Then we call the -// expander to generate the prologue and new loop and epilogue. +// expander to generate the prologue and new loop. //===----------------------------------------------------------------------===// #define GEN_PASS_CLASSES @@ -56,117 +56,42 @@ static Operation *streamPredication(RewriterBase &rewriter, Operation *op, namespace { -//===----------------------------------------------------------------------===// -// Software pipelining generally works by anchoring on global load ops in the -// main loop and rotating the loop to schedule global load ops for future loop -// iterations together with compute for the current iteration. 
In this way, we -// can 1) issue memory operations earlier to hide the latency and 2) break the -// strong dependency inside on loop iteration to give backends flexiblity to -// better interleave instructions for better instruction-level parallelism. -// -// This StreamPipeliner class creates the pipelining schedule and calls the -// PipelineExpander to rewrite the `scf.for` loop accordingly. A schedule -// consists of multiple stages, where ops from different stages can overlap -// executions because the dependencies are loop carried. -// -// The general flow of this process is: -// -// 1. The user provides a `num_stages` that specifies how many stages the -// pipeline will have. The number of stages must be larger than the distance -// from the first independent load to the compute in order to pipeline. -// 2. A schedule is created based on the distance between the global loads -// in the first stages and the compute that uses the loaded values in the -// last stage (num_stages - 1). Each operation will be clustered in the -// order to best overlap with other operations (see details below in the -// initSchedule method). -// 3. When the compute is a tt.dot, the scheduler will insert a shared -// memory allocation between the global load and tt.dot. The ttg.local_store -// will save the global load value to shared memory and the ttg.local_load -// will load the relevant tiles for the tt.dot. These operations will be -// scheduled according to various scheduling schemes outlined below in the -// initSchedule method (see details there). -// 4. Finally the schedule will be passed to the PipelineExpander to rewrite -// accordingly. The new implementation will consist of: -// a. Prologue: containing the ramp-up of num_stages-1 stages for -// iteratorions i=[0, num_stages-1). -// b. New loop: ordered by cluster and iterated on each operation by -// `i + (num_stages-op_stage)`. -// c. Epilogue: ramp-down of the last `num_stages-1` iterations for the -// ops in stages 1 to last_stage. This must consider that the loop -// bounds may be shorter than num_stages. In this case, the epilogue -// iterations must align with the prologue. -// +// Encapsulate stream pipelining +// For each `scf.for` create a StreamPipeliner manager. 
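// The stage/cluster scheduling described in the comment block above is easier
// to see as a concrete call sequence. The sketch below is illustrative only:
// it assumes the tt::CoarseSchedule API used later in this file
// (clusters.newAtBack(), insert(op, stage, cluster)), and loadOp/dotOp are
// placeholder handles, not names taken from this pass.
static void sketchTwoStageSchedule(tt::CoarseSchedule &schedule,
                                   Operation *loadOp, Operation *dotOp,
                                   int numStages) {
  // Clusters fix the relative order of ops inside one iteration of the
  // rewritten loop; the stage says how many iterations ahead of its use an op
  // is issued.
  tt::CoarseSchedule::Cluster loads = schedule.clusters.newAtBack();
  tt::CoarseSchedule::Cluster compute = schedule.clusters.newAtBack();
  // Issue the global load for a future iteration (stage 0) and consume the
  // loaded tile in the last stage, so the load latency overlaps with compute.
  schedule.insert(loadOp, /*stage=*/0, loads);
  schedule.insert(dotOp, /*stage=*/numStages - 1, compute);
}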
class StreamPipeliner { public: - StreamPipeliner(scf::ForOp _forOp, int _numStages, bool _prefetch) - : forOp(_forOp), prefetch(_prefetch), numStages(_numStages + prefetch), - schedule(numStages), + StreamPipeliner(scf::ForOp _forOp, int _numStages) + : forOp(_forOp), schedule(_numStages), numStages(_numStages), axisInfoAnalysis(forOp->getParentOfType()) { options.supportDynamicLoops = true; options.peelEpilogue = true; options.predicateFn = streamPredication; } - LogicalResult pipelineLoop(); - -private: - void initSchedule(int maxIndirectionLevel); - void computeLoadOpsToIndirectionLevelAndUse(); void assignMemoryLayouts(); - LogicalResult scheduleLoads(DenseSet &rootUsers); + void scheduleLoads(DenseSet &rootUsers); void scheduleDependencies(); void scheduleDistanceOneDependencies(); - void scheduleRemainingToLastStage(); + void scheduleRemainingToLastStage(tt::CoarseSchedule::Cluster afterPrologue); - LogicalResult preprocessLoopAndBuildSchedule(); + bool preprocessLoopAndBuildSchedule(); + bool pipelineLoop(); Value createAlloc(Operation *loadOp, ttg::SharedEncodingAttr sharedEnc, unsigned numBuffers); - void createStreamCopy(tt::LoadOp loadOp, Value alloc, Value extractIdx); + void createStreamCopy(tt::LoadOp loadOp, Value alloc, Value extractIdx, + tt::CoarseSchedule::Cluster prefetchCluster); void createStreamOps(); - // Define categories of scheduling details per Operation types. - // The StreamPipeliner schedules 5 types of operations: - // 1. GLOBAL_LOAD: tt.load - // 2. LOCAL_STORE: ttg.local_store (created by the StreamPipeliner) - // 3. LOCAL_LOAD: ttg.local_load (created by the StreamPipeliner) - // 4. COMPUTE: ops that use the loaded data - // 5. TAIL: everything else in the loop - enum SchedType { - SCHED_GLOBAL_LOAD, - SCHED_LOCAL_STORE, - SCHED_LOCAL_LOAD, - SCHED_COMPUTE, - SCHED_TAIL - }; - - void scheduleOp(Operation *op, SchedType type, int stage = -1) { - if (stage < 0) - stage = config[type].stage; - schedule.insert(op, stage, config[type].cluster); - } - private: - // Data members scf::ForOp forOp; - - // User settings - bool prefetch; - int numStages; - - // Scheduling clusters tt::CoarseSchedule schedule; - - // ScheduleConfig lookup by SchedType to get the stage and cluster. - struct ScheduleConfig { - int stage; - tt::CoarseSchedule::Cluster cluster; - }; - SmallVector config; + int numStages; // Mapping and indirection level for each `tt.load` to its use. - SmallVector> loadOpToIndLevelAndUse; + llvm::SmallVector> + loadOpToIndLevelAndUse; struct LoadInfo { // Shared layout is used for loads feeding into dot ops. @@ -191,64 +116,9 @@ class StreamPipeliner { } // namespace -// Init Schedule Config based on settings and loop characteristics. -// Create clusters in order of ops in loop. This can interleave ops -// from different stages in the same cluster to achieve better backend -// scheduling. -// WARNING: Changing the order of schedule.clusters.newAtBack() calls -// can cause invalid schedules to be produced. -void StreamPipeliner::initSchedule(int maxIndirectionLevel) { - int lastStage = numStages - 1; - config.resize(SCHED_TAIL + 1); - - bool isMultibuf = numStages > (2 + maxIndirectionLevel); - if (prefetch) { - // Prefetch Schema cluster order and staging. 
- // for i in (...): - // local_stores: stage=i+1 - // global_loads: stage=i+2 - // compute: stage=i - // local_load: stage=i+1 - // tail: stage=i - config[SCHED_LOCAL_STORE] = {lastStage - 1, schedule.clusters.newAtBack()}; - auto cluster1 = schedule.clusters.newAtBack(); - config[SCHED_GLOBAL_LOAD] = {0, cluster1}; - config[SCHED_COMPUTE] = {lastStage, cluster1}; - config[SCHED_LOCAL_LOAD] = {lastStage - 1, schedule.clusters.newAtBack()}; - config[SCHED_TAIL] = {lastStage, schedule.clusters.newAtBack()}; - } else if (isMultibuf) { - // Streaming Schema cluster order and staging for multi-buffer. - // for i in (...): - // local_stores: stage=i+1 - // global_loads: stage=i+2 - // local_load: stage=i - // compute: stage=i - // tail: stage=i - config[SCHED_LOCAL_STORE] = {lastStage - 1, schedule.clusters.newAtBack()}; - auto cluster1 = schedule.clusters.newAtBack(); - config[SCHED_GLOBAL_LOAD] = {0, cluster1}; - config[SCHED_LOCAL_LOAD] = {lastStage, cluster1}; - config[SCHED_COMPUTE] = {lastStage, cluster1}; - config[SCHED_TAIL] = {lastStage, schedule.clusters.newAtBack()}; - } else { - // Streaming Schema cluster order and staging for single-buffer. - // for i in (...): - // global_loads: stage=i+1 - // local_load: stage=i - // compute: stage=i - // local_stores: stage=i+1 - // tail: stage=i - auto cluster0 = schedule.clusters.newAtBack(); - config[SCHED_GLOBAL_LOAD] = {0, cluster0}; - config[SCHED_LOCAL_LOAD] = {lastStage, schedule.clusters.newAtBack()}; - config[SCHED_COMPUTE] = {lastStage, cluster0}; - config[SCHED_LOCAL_STORE] = {lastStage - 1, schedule.clusters.newAtBack()}; - config[SCHED_TAIL] = {lastStage, schedule.clusters.newAtBack()}; - } -} - -void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc, - Value extractIdx) { +void StreamPipeliner::createStreamCopy( + tt::LoadOp loadOp, Value alloc, Value extractIdx, + tt::CoarseSchedule::Cluster prefetchCluster) { OpBuilder builder(forOp); Value zero = builder.create(forOp.getLoc(), 0, 32); // Replace the load with insert/extract slice. @@ -256,7 +126,6 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc, Location loc = loadOp.getLoc(); Value src = loadOp.getPtr(); Value mask = loadOp.getMask(); - Value other = loadOp.getOther(); tt::MemDescType allocTy = cast(alloc.getType()); SmallVector copyOffsets(allocTy.getRank(), zero); @@ -276,6 +145,8 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc, allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true); auto viewLoad = builder.create(loc, subviewTy, alloc, loadOffsets); + auto storeOp = + builder.create(loc, copy->getResult(0), viewLoad); // Clean up old local caches. SmallVector allocsToErase; for (Operation *user : loadOp->getUsers()) { @@ -287,18 +158,17 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc, for (auto alloc : allocsToErase) alloc.erase(); - // Prefetch load ahead of the dot stage if is used by the dot. - auto storeOp = - builder.create(loc, copy->getResult(0), viewLoad); - scheduleOp(viewLoad, SCHED_LOCAL_STORE); - scheduleOp(storeOp, SCHED_LOCAL_STORE); - - // Create local load auto sharedLoad = builder.create(loc, loadOp.getType(), viewLoad); - Value result = sharedLoad.getResult(); - if (prefetch) - scheduleOp(sharedLoad, SCHED_LOCAL_LOAD); + auto result = sharedLoad->getResults(); + + // Create a select for non-zero other values. 
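  // Masked-off elements of the tile read back from shared memory are not
  // guaranteed to hold the load's `other` value, so the arith.select below
  // re-applies `other`, keyed on the original load mask, before loadOp's uses
  // are rewired to the result.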
+ Value other = loadOp.getOther(); + if (other && !isZeroConst(other)) { + auto select = builder.create( + loc, loadOp.getType(), mask, sharedLoad.getResult(), other); + result = select->getResults(); + } // If the currently processed `LoadOp` is labeled with an index regarding // to which `DotOp` operand the corresponding data belongs to, then label the @@ -309,13 +179,14 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc, storeOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), attr); } - loadOp->replaceAllUsesWith(ValueRange{result}); + loadOp->replaceAllUsesWith(result); - if (prefetch && result.hasOneUse()) { - if (auto cvt = dyn_cast(*result.getUsers().begin())) - scheduleOp(cvt, SCHED_LOCAL_LOAD); + // Prefetch load ahead of the dot stage if is used by the dot. + if (loadToInfo[loadOp].usedByDot) { + assert(numStages >= 2 && "requires num_stages=2 at least"); + schedule.insert(storeOp, numStages - 2, prefetchCluster); + schedule.insert(viewLoad, numStages - 2, prefetchCluster); } - loadOp.erase(); } @@ -477,7 +348,7 @@ void StreamPipeliner::assignMemoryLayouts() { } } -LogicalResult StreamPipeliner::scheduleLoads(DenseSet &rootUsers) { +void StreamPipeliner::scheduleLoads(DenseSet &rootUsers) { // Get all loads that are (transitively) used by dot ops and their distance // to the dot op. computeLoadOpsToIndirectionLevelAndUse(); @@ -490,12 +361,12 @@ LogicalResult StreamPipeliner::scheduleLoads(DenseSet &rootUsers) { } }); if (loadOpToIndLevelAndUse.empty()) - return failure(); + return; // Check which loads are good for pipelining, and assign them memory layouts. assignMemoryLayouts(); if (loadToInfo.empty()) - return failure(); + return; // Filter out load ops that cannot be pipelined. int resize = 0; @@ -511,12 +382,6 @@ LogicalResult StreamPipeliner::scheduleLoads(DenseSet &rootUsers) { for (auto [loadOp, dist, use] : loadOpToIndLevelAndUse) maxIndirectionLevel = std::max(maxIndirectionLevel, dist); - LDBG("maxIndirectionLevel = " << maxIndirectionLevel); - if (maxIndirectionLevel >= numStages) - return failure(); - - initSchedule(maxIndirectionLevel); - // The stage gap between chained loads--this allows us to "spread" loads // with a non-one step in case the number of stages given by the user is // large. @@ -526,18 +391,24 @@ LogicalResult StreamPipeliner::scheduleLoads(DenseSet &rootUsers) { LDBG("stagesBetweenLoads = " << stagesBetweenLoads); // Put the root uses of the loads in the last stage. + tt::CoarseSchedule::Cluster rootUsersCluster = schedule.clusters.newAtFront(); for (auto &[loadOp, dist, use] : loadOpToIndLevelAndUse) { // Non-LoadOp(s) are the (final) root uses of all LoadOp(s). if (!isa(use)) { - scheduleOp(use, SCHED_COMPUTE); + schedule.insert(use, numStages - 1, rootUsersCluster); rootUsers.insert(use); } } + // Create a cluster for load ops at each indirection level. + SmallVector loadsClusters; + for (int i = 0; i <= maxIndirectionLevel; i++) { + loadsClusters.push_back(schedule.clusters.newAtBack()); + } // Assign stages to the loads. for (auto [loadOp, indLevel, _] : loadOpToIndLevelAndUse) { int stage = (maxIndirectionLevel - indLevel) * stagesBetweenLoads; - scheduleOp(loadOp, SCHED_GLOBAL_LOAD, stage); + schedule.insert(loadOp, stage, loadsClusters[indLevel]); } // Calculate distance from the load to the use. @@ -553,8 +424,6 @@ LogicalResult StreamPipeliner::scheduleLoads(DenseSet &rootUsers) { LDBG(" usedByDot: " << info.usedByDot); } }); - - return success(); } // Add dependencies of anchor ops to the coarse schedule. 
Schedule them to @@ -621,23 +490,22 @@ void StreamPipeliner::scheduleDistanceOneDependencies() { } } -void StreamPipeliner::scheduleRemainingToLastStage() { - int lastStage = numStages - 1; +void StreamPipeliner::scheduleRemainingToLastStage( + tt::CoarseSchedule::Cluster afterPrologue) { // Assign the rest of the ops to the last stage. // Take care of the ordering of the ops - uses cannot be scheduled to the // cluster before the definition. DenseMap opToCluster; for (auto &op : forOp.getBody()->without_terminator()) { if (schedule.count(&op) == 0) { - auto schedType = isa(op) ? SCHED_COMPUTE : SCHED_TAIL; - opToCluster[&op] = config[schedType].cluster; + opToCluster[&op] = afterPrologue; } } SmallVector queue; for (auto [op, stage, cluster] : schedule.getOpsInOrder(forOp)) { // We really only care about the producers from the last stage. // Others will be scheduled before these ops anyway. - if (stage == lastStage) { + if (stage == numStages - 1) { queue.push_back(op); } } @@ -655,7 +523,7 @@ void StreamPipeliner::scheduleRemainingToLastStage() { } } for (auto [op, cluster] : opToCluster) { - schedule.insert(op, lastStage, cluster); + schedule.insert(op, numStages - 1, cluster); } } @@ -672,10 +540,8 @@ Value StreamPipeliner::createAlloc(Operation *loadOp, Type memdescType = tt::MemDescType::get(bufferShape, ty.getElementType(), sharedEnc, sharedMemorySpace, /*mutableMemory=*/true); - auto alloc = - builder.create(loadOp->getLoc(), memdescType, Value()); - sharedMemAllocs.push_back(alloc); - return alloc; + return builder.create(loadOp->getLoc(), memdescType, + Value()); } // Convert load ops into shared memory allocation loads and apply @@ -683,20 +549,19 @@ Value StreamPipeliner::createAlloc(Operation *loadOp, void StreamPipeliner::createStreamOps() { // Calculate the number of buffers needed for each load. // TODO: Use the precise number of buffers needed by the particular load. - int maxNumBuffers = -1; - for (auto &[_, info] : loadToInfo) { - int sharedBuffers = info.distToUse - (info.usedByDot ? prefetch : 0); - maxNumBuffers = std::max(maxNumBuffers, sharedBuffers); - } - LDBG("deduced max shared memory buffer number = " << maxNumBuffers); + int numBuffers = -1; + for (auto &[_, info] : loadToInfo) + numBuffers = std::max(numBuffers, info.distToUse); + LDBG("deduced shared memory buffer number = " << numBuffers); SmallVector> loadToAllocs; for (auto &[loadOp, info] : loadToInfo) { if (!info.sharedEncoding) continue; - Value alloc = createAlloc(loadOp, info.sharedEncoding, maxNumBuffers); + Value alloc = createAlloc(loadOp, info.sharedEncoding, numBuffers); assert(alloc && "Failed to create alloc for the async load."); + sharedMemAllocs.push_back(alloc); loadToAllocs.emplace_back(loadOp, alloc); } @@ -709,7 +574,7 @@ void StreamPipeliner::createStreamOps() { Value one = builder.create(loc, 1, 32); Value extractIdx = minusOne; Value numBuffersVal = - builder.create(loc, maxNumBuffers, 32); + builder.create(loc, numBuffers, 32); unsigned newOperandIndex = forOp.getBody()->getNumArguments(); // Patch the loop to add the new loop carried dependencies. @@ -728,23 +593,24 @@ void StreamPipeliner::createStreamOps() { extractIdx, numBuffersVal); extractIdx = builder.create(loc, cndExt, extractIdx, zero); - // Create stream copies. + // Create a cluster for prefetching global reads for the dot. 
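  // This cluster is appended at the back of the schedule and handed to
  // createStreamCopy, which (see above) places the memdesc_subview and
  // local_store of dot-feeding loads into it at stage numStages - 2, grouping
  // the shared-memory writes one stage ahead of the dot that consumes them.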
+ tt::CoarseSchedule::Cluster prefetchCluster = schedule.clusters.newAtBack(); + for (auto &[op, alloc] : loadToAllocs) { if (auto loadOp = dyn_cast(op)) - createStreamCopy(loadOp, alloc, extractIdx); + createStreamCopy(loadOp, alloc, extractIdx, prefetchCluster); } // Patch the yield with the updated counters. appendToForOpYield(forOp, {extractIdx}); } -LogicalResult StreamPipeliner::preprocessLoopAndBuildSchedule() { +bool StreamPipeliner::preprocessLoopAndBuildSchedule() { // Schedule the loads and root ops (dot ops) in the loop. This will give us // a scaffold for the final schedule. DenseSet rootUsers; - if (failed(scheduleLoads(rootUsers))) - return failure(); + scheduleLoads(rootUsers); if (loadToInfo.empty()) - return failure(); + return false; LLVM_DEBUG({ LDBG("Coarse schedule loads only:"); @@ -754,6 +620,13 @@ LogicalResult StreamPipeliner::preprocessLoopAndBuildSchedule() { // Convert the loads into shared memory allocations and loads from them. createStreamOps(); + LLVM_DEBUG({ + LDBG("Coarse schedule with stream loads:"); + schedule.dump(); + }); + + tt::CoarseSchedule::Cluster afterPrologue = schedule.clusters.begin(); + scheduleDependencies(); LLVM_DEBUG({ LDBG("Coarse schedule with dependencies:"); @@ -766,7 +639,7 @@ LogicalResult StreamPipeliner::preprocessLoopAndBuildSchedule() { schedule.dump(); }); - scheduleRemainingToLastStage(); + scheduleRemainingToLastStage(afterPrologue); LLVM_DEBUG({ LDBG("Final coarse schedule:"); schedule.dump(); @@ -789,18 +662,7 @@ LogicalResult StreamPipeliner::preprocessLoopAndBuildSchedule() { // Explicitly deallocate created allocations. for (auto alloc : sharedMemAllocs) builder.create(forOp.getLoc(), alloc); - - return success(); -} - -LogicalResult StreamPipeliner::pipelineLoop() { - if (failed(preprocessLoopAndBuildSchedule())) - return failure(); - LDBG("Loop before sending to expander:\n" << *forOp); - - IRRewriter rewriter(forOp->getContext()); - rewriter.setInsertionPoint(forOp); - return tt::pipelineForLoop(rewriter, forOp, options); + return true; } // Return true if the preconditions for pipelining the loop are met. @@ -820,6 +682,19 @@ static bool checkPrecondition(scf::ForOp forOp) { return !forOp->walk(hasNestedLoopInside).wasInterrupted(); } +bool StreamPipeliner::pipelineLoop() { + if (!checkPrecondition(forOp)) + return false; + + if (!preprocessLoopAndBuildSchedule()) + return false; + LDBG("Loop before sending to expander:\n" << *forOp); + + IRRewriter rewriter(forOp->getContext()); + rewriter.setInsertionPoint(forOp); + return succeeded(tt::pipelineForLoop(rewriter, forOp, options)); +} + namespace { // Go through a single use chain to get the result of the target op after all // unary ops - e.g., `convert_layout`, `fp_to_fp`, etc. 
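// The helper that the comment above describes is outside this hunk; the sketch
// below only illustrates the idea and is an assumption -- the real name and the
// exact set of ops may differ. It follows a single-use chain forward through
// "unary" ops such as ttg::ConvertLayoutOp and tt::FpToFpOp and returns the
// value produced after the last of them.
static Value followUnaryUseChainSketch(Value v) {
  while (v.hasOneUse()) {
    Operation *user = *v.getUsers().begin();
    if (!isa<ttg::ConvertLayoutOp, tt::FpToFpOp>(user))
      break;
    v = user->getResult(0);
  }
  return v;
}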
@@ -858,10 +733,7 @@ void labelLoadOpsForTritonDot(scf::ForOp forOp) { struct PipelinePass : public TritonAMDGPUStreamPipelineV2Base { PipelinePass() = default; - PipelinePass(int32_t numStages, int32_t prefetch) { - this->numStages = numStages; - this->prefetch = prefetch; - } + PipelinePass(int32_t numStages) { this->numStages = numStages; } void runOnOperation() override { SmallVector loops; @@ -873,11 +745,8 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineV2Base { }); for (scf::ForOp forOp : loops) { - if (!checkPrecondition(forOp)) - continue; - StreamPipeliner sp(forOp, getNumStagesOrDefault(forOp), prefetch); - if (failed(sp.pipelineLoop())) - continue; + StreamPipeliner sp(forOp, getNumStagesOrDefault(forOp)); + sp.pipelineLoop(); } } @@ -893,6 +762,6 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineV2Base { } // anonymous namespace std::unique_ptr -mlir::createTritonAMDGPUStreamPipelineV2Pass(int numStages, int prefetch) { - return std::make_unique(numStages, prefetch); +mlir::createTritonAMDGPUStreamPipelineV2Pass(int numStages) { + return std::make_unique(numStages); } diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc index a9bd3e9b7fb7..9eab3771263e 100644 --- a/third_party/amd/python/triton_amd.cc +++ b/third_party/amd/python/triton_amd.cc @@ -72,8 +72,8 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) { mlir::createTritonAMDGPUConvertToBufferOpsPass); ADD_PASS_WRAPPER_0("add_reorder_instructions", mlir::createTritonAMDGPUReorderInstructionsPass); - ADD_PASS_WRAPPER_2("add_stream_pipelinev2", - mlir::createTritonAMDGPUStreamPipelineV2Pass, int, int); + ADD_PASS_WRAPPER_1("add_stream_pipelinev2", + mlir::createTritonAMDGPUStreamPipelineV2Pass, int); } void addControlConstant(llvm::Module *module, const char *name,