From 4a9734d56a647215a042d0f2ca6d11c7bd6ec516 Mon Sep 17 00:00:00 2001
From: Maksim Levental
Date: Wed, 5 Feb 2025 15:13:02 -0500
Subject: [PATCH] add a max trip count and snap lattices to the max range for
 loops that exceed it

---
 ...amd-convert-buffer-ops-range-analysis.mlir | 67 +++++++++++++++++++
 .../ConvertToBufferOps.cpp                    | 15 +++--
 2 files changed, 78 insertions(+), 4 deletions(-)

diff --git a/test/TritonGPU/amd/amd-convert-buffer-ops-range-analysis.mlir b/test/TritonGPU/amd/amd-convert-buffer-ops-range-analysis.mlir
index ee56b3d496122..a8a804fb60690 100644
--- a/test/TritonGPU/amd/amd-convert-buffer-ops-range-analysis.mlir
+++ b/test/TritonGPU/amd/amd-convert-buffer-ops-range-analysis.mlir
@@ -315,6 +315,73 @@ module attributes {"ttg.num-warps" = 4 : i32} {
 
 // -----
 
+// CHECK-LABEL: tt.func @forNestedOverMaxTripCount(
+// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>, %[[VAL_1:.*]]: tensor<1024xf32>) -> tensor<1024xf32> {
+// CHECK: %[[VAL_2:.*]] = arith.constant {__amdgpuconvertbufferops.output_range = [0, 0]} dense<0> : tensor<1024xi64>
+// CHECK: %[[VAL_3:.*]] = arith.constant {__amdgpuconvertbufferops.output_range = [1024, 1024]} 1024 : i32
+// CHECK: %[[VAL_4:.*]] = arith.constant {__amdgpuconvertbufferops.output_range = [0, 0]} 0 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant {__amdgpuconvertbufferops.output_range = [128, 128]} 128 : index
+// CHECK: %[[VAL_6:.*]] = arith.constant {__amdgpuconvertbufferops.output_range = [1, 1]} 1 : index
+// CHECK: %[[VAL_7:.*]] = tt.get_program_id x {__amdgpuconvertbufferops.output_range = [0, 2048]} : i32
+// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_7]], %[[VAL_3]] {__amdgpuconvertbufferops.output_range = [0, 2097152]} : i32
+// CHECK: %[[VAL_9:.*]] = tt.make_range {__amdgpuconvertbufferops.output_range = [0, 1024], end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
+// CHECK: %[[VAL_10:.*]]:3 = scf.for %[[VAL_11:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] iter_args(%[[VAL_12:.*]] = %[[VAL_0]], %[[VAL_13:.*]] = %[[VAL_2]], %[[VAL_14:.*]] = %[[VAL_1]]) -> (!tt.ptr<f32>, tensor<1024xi64>, tensor<1024xf32>) {
+// CHECK: %[[VAL_15:.*]]:3 = scf.for %[[VAL_16:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] iter_args(%[[VAL_17:.*]] = %[[VAL_12]], %[[VAL_18:.*]] = %[[VAL_13]], %[[VAL_19:.*]] = %[[VAL_14]]) -> (!tt.ptr<f32>, tensor<1024xi64>, tensor<1024xf32>) {
+// CHECK: %[[VAL_20:.*]] = tt.addptr %[[VAL_17]], %[[VAL_8]] : !tt.ptr<f32>, i32
+// CHECK: %[[VAL_21:.*]] = arith.extsi %[[VAL_9]] {__amdgpuconvertbufferops.output_range = [0, 1024]} : tensor<1024xi32> to tensor<1024xi64>
+// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_21]], %[[VAL_18]] {__amdgpuconvertbufferops.output_range = [-9223372036854775808, 9223372036854775807]} : tensor<1024xi64>
+// CHECK: %[[VAL_23:.*]] = tt.splat %[[VAL_20]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
+// CHECK: %[[VAL_24:.*]] = tt.addptr %[[VAL_23]], %[[VAL_22]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi64>
+// CHECK: %[[VAL_25:.*]] = tt.load %[[VAL_24]] : tensor<1024x!tt.ptr<f32>>
+// CHECK: %[[VAL_26:.*]] = arith.addf %[[VAL_25]], %[[VAL_19]] : tensor<1024xf32>
+// CHECK: scf.yield %[[VAL_20]], %[[VAL_22]], %[[VAL_26]] : !tt.ptr<f32>, tensor<1024xi64>, tensor<1024xf32>
+// CHECK: }
+// CHECK: scf.yield %[[VAL_27:.*]]#0, %[[VAL_27]]#1, %[[VAL_27]]#2 : !tt.ptr<f32>, tensor<1024xi64>, tensor<1024xf32>
+// CHECK: }
+// CHECK: %[[VAL_28:.*]] = tt.addptr %[[VAL_29:.*]]#0, %[[VAL_8]] : !tt.ptr<f32>, i32
+// CHECK: %[[VAL_30:.*]] = arith.extsi %[[VAL_9]] {__amdgpuconvertbufferops.output_range = [0, 1024]} : tensor<1024xi32> to tensor<1024xi64>
+// CHECK: %[[VAL_31:.*]] = arith.addi %[[VAL_30]], %[[VAL_29]]#1 {__amdgpuconvertbufferops.output_range = [-9223372036854775808, 9223372036854775807]} : tensor<1024xi64>
+// CHECK: %[[VAL_32:.*]] = tt.splat %[[VAL_28]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
+// CHECK: %[[VAL_33:.*]] = tt.addptr %[[VAL_32]], %[[VAL_31]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi64>
+// CHECK: %[[VAL_34:.*]] = tt.load %[[VAL_33]] : tensor<1024x!tt.ptr<f32>>
+// CHECK: tt.return %[[VAL_34]] : tensor<1024xf32>
+// CHECK: }
+
+module attributes {"ttg.num-warps" = 4 : i32} {
+  tt.func @forNestedOverMaxTripCount(%arg0: !tt.ptr<f32>, %arg1: tensor<1024xf32>) -> tensor<1024xf32> {
+    %cst = arith.constant dense<0> : tensor<1024xi64>
+    %c1024_i32 = arith.constant 1024 : i32
+    %c0 = arith.constant 0 : index
+    %c128 = arith.constant 128 : index
+    %c1 = arith.constant 1 : index
+    %0 = tt.get_program_id x : i32
+    %1 = arith.muli %0, %c1024_i32 : i32
+    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
+    %3:3 = scf.for %arg2 = %c0 to %c128 step %c1 iter_args(%arg3 = %arg0, %arg4 = %cst, %arg5 = %arg1) -> (!tt.ptr<f32>, tensor<1024xi64>, tensor<1024xf32>) {
+      %10:3 = scf.for %arg6 = %c0 to %c128 step %c1 iter_args(%arg7 = %arg3, %arg8 = %arg4, %arg9 = %arg5) -> (!tt.ptr<f32>, tensor<1024xi64>, tensor<1024xf32>) {
+        %11 = tt.addptr %arg7, %1 : !tt.ptr<f32>, i32
+        %12 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64>
+        %13 = arith.addi %12, %arg8 : tensor<1024xi64>
+        %14 = tt.splat %11 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
+        %15 = tt.addptr %14, %13 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi64>
+        %16 = tt.load %15 : tensor<1024x!tt.ptr<f32>>
+        %17 = arith.addf %16, %arg9 : tensor<1024xf32>
+        scf.yield %11, %13, %17 : !tt.ptr<f32>, tensor<1024xi64>, tensor<1024xf32>
+      }
+      scf.yield %10#0, %10#1, %10#2 : !tt.ptr<f32>, tensor<1024xi64>, tensor<1024xf32>
+    }
+    %4 = tt.addptr %3#0, %1 : !tt.ptr<f32>, i32
+    %5 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64>
+    %6 = arith.addi %5, %3#1 : tensor<1024xi64>
+    %7 = tt.splat %4 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
+    %8 = tt.addptr %7, %6 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi64>
+    %9 = tt.load %8 : tensor<1024x!tt.ptr<f32>>
+    tt.return %9 : tensor<1024xf32>
+  }
+}
+
+// -----
+
 // CHECK-LABEL: tt.func @ifOp(
 // CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>, %[[VAL_1:.*]]: tensor<1024xf32>, %[[VAL_2:.*]]: i1) -> tensor<1024xf32> {
 // CHECK: %[[VAL_3:.*]] = arith.constant 0 : i32
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp
index ab5b278326c2c..b08c0f33fb27f 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp
@@ -39,7 +39,7 @@ namespace tt = mlir::triton;
 
 namespace {
 
-constexpr int64_t kDefaultMaxTripCount = 0;
+constexpr int64_t kDefaultMaxTripCount = 1024;
 
 const std::string kConvertBufferOpsPrefix = "__amdgpuconvertbufferops.";
 const std::string kOutputRange = kConvertBufferOpsPrefix + "output_range";
@@ -237,11 +237,18 @@ struct TritonIntegerRangeAnalysis : dataflow::IntegerRangeAnalysis {
     for (auto [oper, argLat] :
          llvm::zip(*operands, ArrayRef(lattices).drop_front(firstIndex))) {
       std::pair loopArgLat = {loop, argLat};
+      // If we've "run the loop" trip-count times, stop propagating.
       if (loop && loopVisits[loopArgLat] >= loopTripCounts[loop])
         continue;
-      const dataflow::IntegerValueRangeLattice *rhs =
-          getLatticeElementFor(point, oper);
-      ChangeResult changed = argLat->join(*rhs);
+      ChangeResult changed;
+      if (loop && loopTripCounts[loop] > kDefaultMaxTripCount) {
+        // If the loop's trip count is too large, "snap" the arg lattices to
+        // the max range (which "snaps" body values to the max range as well).
+        changed = argLat->join(IntegerValueRange::getMaxRange(oper));
+      } else {
+        // Otherwise, propagate the predecessor operands' ranges as usual.
+        changed = argLat->join(*getLatticeElementFor(point, oper));
+      }
       propagateIfChanged(argLat, changed);
       if (loop && changed == ChangeResult::Change)
         ++loopVisits[loopArgLat];
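
Reviewer note (not part of the patch): the sketch below models the new gating logic in isolation, for anyone who wants to follow the visit-count/threshold interplay without the MLIR dataflow machinery. Every name in it (Range, LoopArg, propagateIntoLoopArg, kMaxTripCount) is an illustrative stand-in, not the Triton or MLIR API; the real pass operates on dataflow::IntegerValueRangeLattice via join/propagateIfChanged exactly as in the diff above.

// Standalone C++ model of the trip-count gating; all names are illustrative.
#include <algorithm>
#include <cstdint>
#include <limits>

constexpr int64_t kMaxTripCount = 1024; // mirrors kDefaultMaxTripCount

struct Range {
  int64_t lo, hi;
  static Range max() {
    return {std::numeric_limits<int64_t>::min(),
            std::numeric_limits<int64_t>::max()};
  }
  // Lattice join: the smallest range containing both inputs. Returns true
  // if this range actually widened (analogue of ChangeResult::Change).
  bool join(const Range &other) {
    int64_t newLo = std::min(lo, other.lo);
    int64_t newHi = std::max(hi, other.hi);
    bool changed = newLo != lo || newHi != hi;
    lo = newLo;
    hi = newHi;
    return changed;
  }
};

// One loop-carried argument lattice plus its per-loop visit counter.
struct LoopArg {
  Range lattice{0, 0};
  int64_t visits = 0;
};

// One propagation step into a loop-carried argument for a loop that runs
// `tripCount` iterations.
void propagateIntoLoopArg(LoopArg &arg, const Range &incoming,
                          int64_t tripCount) {
  // Once we've "run the loop" tripCount times, stop propagating: the
  // lattice is already as wide as the loop can make it.
  if (arg.visits >= tripCount)
    return;
  bool changed;
  if (tripCount > kMaxTripCount) {
    // Trip count too large to iterate to a fixpoint in reasonable time:
    // snap the argument straight to the maximal range.
    changed = arg.lattice.join(Range::max());
  } else {
    changed = arg.lattice.join(incoming);
  }
  if (changed)
    ++arg.visits;
}

Judging by the test, loopTripCounts accounts for loop nesting: each scf.for in forNestedOverMaxTripCount runs only 128 iterations, but the nest as a whole runs 128 * 128 = 16384 > kDefaultMaxTripCount (1024). That is why the CHECK lines show the loop-carried arith.addi results snapped to the full i64 range [-9223372036854775808, 9223372036854775807], while loop-invariant values such as the muli result keep their tight ranges.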