add maxtripcount and logic to snap lattices to max for such loops
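In effect, the change raises kDefaultMaxTripCount from 0 to 1024 and makes the range analysis give up on loops whose trip count exceeds that cap: instead of propagating predecessor ranges iteration by iteration, their iter-arg lattices are joined with the maximal integer range. The new test exercises this with two nested 128-iteration scf.for loops.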
makslevental committed Feb 5, 2025
1 parent 961dc72 commit 4a9734d
Showing 2 changed files with 78 additions and 4 deletions.
67 changes: 67 additions & 0 deletions test/TritonGPU/amd/amd-convert-buffer-ops-range-analysis.mlir
@@ -315,6 +315,73 @@ module attributes {"ttg.num-warps" = 4 : i32} {

// -----

// CHECK-LABEL: tt.func @forNestedOverMaxTripCount(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>, %[[VAL_1:.*]]: tensor<1024xf32>) -> tensor<1024xf32> {
// CHECK: %[[VAL_2:.*]] = arith.constant {__amdgpuconvertbufferops.output_range = [0, 0]} dense<0> : tensor<1024xi64>
// CHECK: %[[VAL_3:.*]] = arith.constant {__amdgpuconvertbufferops.output_range = [1024, 1024]} 1024 : i32
// CHECK: %[[VAL_4:.*]] = arith.constant {__amdgpuconvertbufferops.output_range = [0, 0]} 0 : index
// CHECK: %[[VAL_5:.*]] = arith.constant {__amdgpuconvertbufferops.output_range = [128, 128]} 128 : index
// CHECK: %[[VAL_6:.*]] = arith.constant {__amdgpuconvertbufferops.output_range = [1, 1]} 1 : index
// CHECK: %[[VAL_7:.*]] = tt.get_program_id x {__amdgpuconvertbufferops.output_range = [0, 2048]} : i32
// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_7]], %[[VAL_3]] {__amdgpuconvertbufferops.output_range = [0, 2097152]} : i32
// CHECK: %[[VAL_9:.*]] = tt.make_range {__amdgpuconvertbufferops.output_range = [0, 1024], end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
// CHECK: %[[VAL_10:.*]]:3 = scf.for %[[VAL_11:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] iter_args(%[[VAL_12:.*]] = %[[VAL_0]], %[[VAL_13:.*]] = %[[VAL_2]], %[[VAL_14:.*]] = %[[VAL_1]]) -> (!tt.ptr<f32>, tensor<1024xi64>, tensor<1024xf32>) {
// CHECK: %[[VAL_15:.*]]:3 = scf.for %[[VAL_16:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] iter_args(%[[VAL_17:.*]] = %[[VAL_12]], %[[VAL_18:.*]] = %[[VAL_13]], %[[VAL_19:.*]] = %[[VAL_14]]) -> (!tt.ptr<f32>, tensor<1024xi64>, tensor<1024xf32>) {
// CHECK: %[[VAL_20:.*]] = tt.addptr %[[VAL_17]], %[[VAL_8]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_21:.*]] = arith.extsi %[[VAL_9]] {__amdgpuconvertbufferops.output_range = [0, 1024]} : tensor<1024xi32> to tensor<1024xi64>
// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_21]], %[[VAL_18]] {__amdgpuconvertbufferops.output_range = [-9223372036854775808, 9223372036854775807]} : tensor<1024xi64>
// CHECK: %[[VAL_23:.*]] = tt.splat %[[VAL_20]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_24:.*]] = tt.addptr %[[VAL_23]], %[[VAL_22]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi64>
// CHECK: %[[VAL_25:.*]] = tt.load %[[VAL_24]] : tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_26:.*]] = arith.addf %[[VAL_25]], %[[VAL_19]] : tensor<1024xf32>
// CHECK: scf.yield %[[VAL_20]], %[[VAL_22]], %[[VAL_26]] : !tt.ptr<f32>, tensor<1024xi64>, tensor<1024xf32>
// CHECK: }
// CHECK: scf.yield %[[VAL_27:.*]]#0, %[[VAL_27]]#1, %[[VAL_27]]#2 : !tt.ptr<f32>, tensor<1024xi64>, tensor<1024xf32>
// CHECK: }
// CHECK: %[[VAL_28:.*]] = tt.addptr %[[VAL_29:.*]]#0, %[[VAL_8]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_30:.*]] = arith.extsi %[[VAL_9]] {__amdgpuconvertbufferops.output_range = [0, 1024]} : tensor<1024xi32> to tensor<1024xi64>
// CHECK: %[[VAL_31:.*]] = arith.addi %[[VAL_30]], %[[VAL_29]]#1 {__amdgpuconvertbufferops.output_range = [-9223372036854775808, 9223372036854775807]} : tensor<1024xi64>
// CHECK: %[[VAL_32:.*]] = tt.splat %[[VAL_28]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_33:.*]] = tt.addptr %[[VAL_32]], %[[VAL_31]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi64>
// CHECK: %[[VAL_34:.*]] = tt.load %[[VAL_33]] : tensor<1024x!tt.ptr<f32>>
// CHECK: tt.return %[[VAL_34]] : tensor<1024xf32>
// CHECK: }

module attributes {"ttg.num-warps" = 4 : i32} {
  tt.func @forNestedOverMaxTripCount(%arg0: !tt.ptr<f32>, %arg1: tensor<1024xf32>) -> tensor<1024xf32> {
    %cst = arith.constant dense<0> : tensor<1024xi64>
    %c1024_i32 = arith.constant 1024 : i32
    %c0 = arith.constant 0 : index
    %c128 = arith.constant 128 : index
    %c1 = arith.constant 1 : index
    %0 = tt.get_program_id x : i32
    %1 = arith.muli %0, %c1024_i32 : i32
    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
    %3:3 = scf.for %arg2 = %c0 to %c128 step %c1 iter_args(%arg3 = %arg0, %arg4 = %cst, %arg5 = %arg1) -> (!tt.ptr<f32>, tensor<1024xi64>, tensor<1024xf32>) {
      %10:3 = scf.for %arg6 = %c0 to %c128 step %c1 iter_args(%arg7 = %arg3, %arg8 = %arg4, %arg9 = %arg5) -> (!tt.ptr<f32>, tensor<1024xi64>, tensor<1024xf32>) {
        %11 = tt.addptr %arg7, %1 : !tt.ptr<f32>, i32
        %12 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64>
        %13 = arith.addi %12, %arg8 : tensor<1024xi64>
        %14 = tt.splat %11 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
        %15 = tt.addptr %14, %13 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi64>
        %16 = tt.load %15 : tensor<1024x!tt.ptr<f32>>
        %17 = arith.addf %16, %arg9 : tensor<1024xf32>
        scf.yield %11, %13, %17 : !tt.ptr<f32>, tensor<1024xi64>, tensor<1024xf32>
      }
      scf.yield %10#0, %10#1, %10#2 : !tt.ptr<f32>, tensor<1024xi64>, tensor<1024xf32>
    }
    %4 = tt.addptr %3#0, %1 : !tt.ptr<f32>, i32
    %5 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64>
    %6 = arith.addi %5, %3#1 : tensor<1024xi64>
    %7 = tt.splat %4 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
    %8 = tt.addptr %7, %6 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi64>
    %9 = tt.load %8 : tensor<1024x!tt.ptr<f32>>
    tt.return %9 : tensor<1024xf32>
  }
}
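Note on the expected ranges above: each scf.for runs only 128 iterations (%c0 to %c128, step %c1), which on its own is below the new 1024 cap, yet the loop-carried i64 offset is annotated with the full range [-9223372036854775808, 9223372036854775807]. This suggests the analysis treats the nested loops' combined trip count, 128 * 128 = 16384 > 1024, as over the limit, snapping the iter-arg lattices to the maximal range; as a result, the tt.load ops are expected to stay as-is rather than be converted to buffer loads.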

// -----

// CHECK-LABEL: tt.func @ifOp(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>, %[[VAL_1:.*]]: tensor<1024xf32>, %[[VAL_2:.*]]: i1) -> tensor<1024xf32> {
// CHECK: %[[VAL_3:.*]] = arith.constant 0 : i32
15 changes: 11 additions & 4 deletions third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp
@@ -39,7 +39,7 @@ namespace tt = mlir::triton;

namespace {

-constexpr int64_t kDefaultMaxTripCount = 0;
+constexpr int64_t kDefaultMaxTripCount = 1024;
const std::string kConvertBufferOpsPrefix = "__amdgpuconvertbufferops.";
const std::string kOutputRange = kConvertBufferOpsPrefix + "output_range";

@@ -237,11 +237,18 @@ struct TritonIntegerRangeAnalysis : dataflow::IntegerRangeAnalysis {
    for (auto [oper, argLat] :
         llvm::zip(*operands, ArrayRef(lattices).drop_front(firstIndex))) {
      std::pair loopArgLat = {loop, argLat};
      // If we've "run the loop" #tripcount times, stop propagating.
      if (loop && loopVisits[loopArgLat] >= loopTripCounts[loop])
        continue;
-     const dataflow::IntegerValueRangeLattice *rhs =
-         getLatticeElementFor(point, oper);
-     ChangeResult changed = argLat->join(*rhs);
+     ChangeResult changed;
+     if (loop && loopTripCounts[loop] > kDefaultMaxTripCount) {
+       // If the loop's tripcount is too large, "snap" arg lattices to max
+       // range (which will "snap" body values to max range as well).
+       changed = argLat->join(IntegerValueRange::getMaxRange(oper));
+     } else {
+       // Else, propagate pred operands.
+       changed = argLat->join(*getLatticeElementFor(point, oper));
+     }
      propagateIfChanged(argLat, changed);
      if (loop && changed == ChangeResult::Change)
        ++loopVisits[loopArgLat];
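
A minimal standalone sketch of the rule above, using a toy interval type in place of MLIR's IntegerValueRange lattice (the Interval struct, joinLoopArg, and the hard-coded 1024 cap are illustrative assumptions mirroring kDefaultMaxTripCount, not the actual analysis API):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <limits>

// Toy stand-in for an integer-range lattice element.
struct Interval {
  int64_t lo, hi;
  // Join = hull of the two ranges.
  Interval join(const Interval &other) const {
    return {std::min(lo, other.lo), std::max(hi, other.hi)};
  }
  static Interval maxRange() {
    return {std::numeric_limits<int64_t>::min(),
            std::numeric_limits<int64_t>::max()};
  }
};

constexpr int64_t kMaxTripCount = 1024; // mirrors kDefaultMaxTripCount

// Mirrors the branch above: if the loop's trip count exceeds the cap, snap the
// iter-arg lattice to the maximal range; otherwise join with the incoming
// predecessor range as before.
Interval joinLoopArg(Interval argLat, Interval incoming, int64_t tripCount) {
  if (tripCount > kMaxTripCount)
    return argLat.join(Interval::maxRange());
  return argLat.join(incoming);
}

int main() {
  Interval arg{0, 0};     // loop-carried offset starts at 0
  Interval step{0, 1024}; // per-iteration increment range, as in the test
  Interval single = joinLoopArg(arg, step, 128);       // one 128-trip loop
  Interval nested = joinLoopArg(arg, step, 128 * 128); // nested: 16384 trips
  std::cout << "single: [" << single.lo << ", " << single.hi << "]\n"
            << "nested: [" << nested.lo << ", " << nested.hi << "]\n";
}

Running it prints a bounded [0, 1024] for the small trip count and the full int64 range for the nested case, matching the ranges the new test expects.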
