Merge OpenAI Triton commit 755d416 #3058

Merged · 5 commits · Dec 21, 2024
14 changes: 14 additions & 0 deletions include/triton/Tools/LayoutUtils.h
@@ -0,0 +1,14 @@
#ifndef TRITON_TOOLS_LAYOUTUTILS_H
#define TRITON_TOOLS_LAYOUTUTILS_H

#include "triton/Tools/LinearLayout.h"

namespace mlir::triton {
// Is the sublayout defined from dimNames to dimNames the identity?
// In particular, are the input and output sizes in these dimensions
// the same, and are the bases the identity?
bool squareSublayoutIsIdentity(const LinearLayout &ll,
ArrayRef<StringAttr> dimNames);
} // namespace mlir::triton

#endif // TRITON_TOOLS_LAYOUTUTILS_H
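
A hedged usage sketch of the new free function, modeled on the conventions of LinearLayoutTest.cpp (the standalone sketch() helper and its context argument are illustrative assumptions, not part of this PR):

#include <cassert>

#include "mlir/IR/MLIRContext.h"
#include "triton/Tools/LayoutUtils.h"

using namespace mlir;
using namespace mlir::triton;

// identity1D(n, in, out) maps i -> i over a size-n dimension, so a product of
// identity1D factors is square and acts as the identity on each dimension.
static void sketch(MLIRContext *ctx) {
  StringAttr kReg = StringAttr::get(ctx, "register");
  StringAttr kLane = StringAttr::get(ctx, "lane");
  LinearLayout ll = LinearLayout::identity1D(8, kReg, kReg) *
                    LinearLayout::identity1D(4, kLane, kLane);
  assert(squareSublayoutIsIdentity(ll, {kReg}));
  assert(squareSublayoutIsIdentity(ll, {kReg, kLane}));
}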
5 changes: 0 additions & 5 deletions include/triton/Tools/LinearLayout.h
@@ -611,11 +611,6 @@ class LinearLayout {
bool sublayoutIsZero(ArrayRef<StringAttr> inDimNames,
ArrayRef<StringAttr> outDimNames) const;

// Is the sublayout defined from dimNames to dimNames the identity?
// In particular, is the input and output size in these dimensions
// the same, and are the bases the identity?
bool squareSublayoutIsIdentity(ArrayRef<StringAttr> dimNames) const;

// Computes and returns L(x, y, z).
//
// If you want to apply the layout to mlir Values instead of integers, that
7 changes: 3 additions & 4 deletions lib/Analysis/Utility.cpp
@@ -225,15 +225,14 @@ bool ReduceOpHelper::isSupportedLayout() {
}

auto srcLayout = getSrcLayout();
if (isa<BlockedEncodingAttr>(srcLayout)) {
if (isa<BlockedEncodingAttr, LinearEncodingAttr, SliceEncodingAttr>(
srcLayout)) {
return true;
}

if (auto mmaLayout = dyn_cast<MmaEncodingTrait>(srcLayout)) {
return mmaLayout.supportReduction();
}
if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(srcLayout)) {
return true;
}
return false;
}

6 changes: 3 additions & 3 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -1470,7 +1470,7 @@ SmallVector<unsigned> basesPerDim(const LinearLayout::BasesT &namedBases,

SmallVector<unsigned> ret(rank, 1);
auto nonZero = [](auto val) { return val != 0; };
int nonZeroIdx = -1;
int nonZeroIdx = 0;
for (const auto &basis : bases) {
auto it = std::find_if(basis.begin(), basis.end(), nonZero);
// Bases can have one or zero non-zero elements
@@ -1482,7 +1482,6 @@ SmallVector<unsigned> basesPerDim(const LinearLayout::BasesT &namedBases,
} else if (!skipBroadcast) {
// If we've seen a non-zero basis, we double the size of the previous dim
// This is just needed to count the CTAsPerCGA
assert(nonZeroIdx != -1);
ret[nonZeroIdx] *= 2;
}
}
@@ -1633,7 +1632,8 @@ LinearEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type) const {
// the invariant that the shape of the LL is that of the tensor
// We choose the former for BC
auto ll = *toLinearLayout(shape);
return basesPerDim(ll, StringAttr::get(getContext(), "register"));
return basesPerDim(ll, StringAttr::get(getContext(), "register"),
/*skipBroadcast=*/false);
}

23 changes: 22 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/LoopScheduling.cpp
@@ -23,6 +23,27 @@ namespace gpu {

namespace {

bool hasGpuBarriers(scf::ForOp forOp) {
WalkResult result = forOp.walk(
[&](mlir::gpu::BarrierOp barrier) { return WalkResult::interrupt(); });
return result.wasInterrupted();
}

// Return true if the preconditions for pipelining the loop are met.
bool isSafeToPipeline(scf::ForOp forOp) {
// Skip loop with distance > 1 for now.
// TODO: relax the constraint in the expander.
if (loopHasDistGreaterThanOne(forOp))
return false;
// Don't pipeline outer loops.
if (isOuterLoop(forOp))
return false;
// Skip loops with barriers.
if (hasGpuBarriers(forOp))
return false;
return true;
}

bool hasLatenciesAssigned(scf::ForOp forOp,
const DenseMap<Operation *, int> &opLatency) {
for (auto &op : forOp.getBody()->without_terminator()) {
@@ -261,7 +282,7 @@ void scheduleRemainingToLastStage(scf::ForOp forOp, CoarseSchedule &schedule,

void scheduleLoop(scf::ForOp forOp,
const DenseMap<Operation *, int> &opLatency) {
if (!hasLatenciesAssigned(forOp, opLatency))
if (!hasLatenciesAssigned(forOp, opLatency) || !isSafeToPipeline(forOp))
return;
// Based on the latencies, schedule the key ops to the stages.
CoarseSchedule schedule = scheduleKeyOps(forOp, opLatency);
14 changes: 0 additions & 14 deletions lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp
@@ -33,18 +33,6 @@ namespace gpu {
#define GEN_PASS_DEF_TRITONGPUPIPELINE
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

// Return true if the preconditions for pipelining the loop are met.
static bool preCondition(scf::ForOp forOp) {
// Skip loop with distance > 1 for now.
// TODO: relax the constraint in the expander.
if (loopHasDistGreaterThanOne(forOp))
return false;
// Don't pipeline outer loops.
if (isOuterLoop(forOp))
return false;
return true;
}

static void tryAndPipelineOuterLoop(scf::ForOp forOp) {
mlir::triton::PipeliningOption options;
bool foundSchedule = false;
@@ -60,8 +48,6 @@ static void tryAndPipelineOuterLoop(scf::ForOp forOp) {

static bool pipelineLoop(scf::ForOp forOp, int numStages) {
mlir::triton::PipeliningOption options;
if (!preCondition(forOp))
return false;

bool foundSchedule = false;
foundSchedule = preProcessLoopAndGetSchedule(forOp, numStages, options);
1 change: 1 addition & 0 deletions lib/Tools/CMakeLists.txt
@@ -1,4 +1,5 @@
add_triton_library(TritonTools
LayoutUtils.cpp
LinearLayout.cpp

DEPENDS
32 changes: 32 additions & 0 deletions lib/Tools/LayoutUtils.cpp
@@ -0,0 +1,32 @@
#include "triton/Tools/LayoutUtils.h"

namespace mlir::triton {

bool squareSublayoutIsIdentity(const LinearLayout &ll,
ArrayRef<StringAttr> dimNames) {
// The empty layout is the identity
if (dimNames.size() == 0) {
return true;
}
// Check that the input-output sizes are the same
LinearLayout sl = ll.sublayout(dimNames, dimNames);
for (StringAttr dim : dimNames) {
if (ll.getInDimSize(dim) != ll.getOutDimSize(dim)) {
return false;
}
}
// Once the input and output dimensions are the same, we can just check
// that the bases for the single remaining dimension are the identity.
sl = sl.flattenIns().flattenOuts();
int b = 0;
const auto &inDimBases = sl.getBases().begin()->second;
for (auto basis : inDimBases) {
if (basis[0] != (1 << b)) {
return false;
}
b++;
}
return true;
}

} // namespace mlir::triton
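
For intuition on the basis check above: after flattening, the identity layout over a size-2^k dimension has bases [1, 2, 4, ..., 2^(k-1)], which is exactly the basis[0] == (1 << b) condition. Continuing the hedged sketch from above (constructor syntax as used in LinearLayoutTest.cpp):

// Identity over a size-8 "register" dim: bases 1, 2, 4 pass basis[0] == 1 << b.
LinearLayout id({{kReg, {{1}, {2}, {4}}}}, {kReg});
assert(squareSublayoutIsIdentity(id, {kReg}));

// Swapping the first two bases yields a bit-permutation: the sublayout is
// still square (input and output sizes match), but the bases are not the
// identity, so the check returns false.
LinearLayout perm({{kReg, {{2}, {1}, {4}}}}, {kReg});
assert(!squareSublayoutIsIdentity(perm, {kReg}));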
30 changes: 2 additions & 28 deletions lib/Tools/LinearLayout.cpp
@@ -6,6 +6,7 @@

#include "mlir/IR/BuiltinAttributes.h"
#include "third_party/f2reduce/f2reduce.h"
#include "triton/Tools/LayoutUtils.h"
#include "triton/Tools/StrUtil.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
@@ -651,7 +652,7 @@ bool LinearLayout::isTrivialOver(ArrayRef<StringAttr> dimNames) const {
// We can quotient out dimNames iff they don't affect the remainingInDimNames
// in the result. In other words, we want to check that B is zero, and C is
// zero, and D is the identity
return squareSublayoutIsIdentity(dimNames) &&
return squareSublayoutIsIdentity(*this, dimNames) &&
sublayoutIsZero(remainingInDimNames, dimNames) &&
sublayoutIsZero(dimNames, remainingOutDimNames);
}
@@ -730,33 +731,6 @@ bool LinearLayout::sublayoutIsZero(ArrayRef<StringAttr> inDimNames,
return true;
}

bool LinearLayout::squareSublayoutIsIdentity(
ArrayRef<StringAttr> dimNames) const {
// The empty layout is the identity
if (dimNames.size() == 0) {
return true;
}
// Check that the input-output sizes are the same
LinearLayout sl = sublayout(dimNames, dimNames);
for (StringAttr dim : dimNames) {
if (getInDimSize(dim) != getOutDimSize(dim)) {
return false;
}
}
// Once the inputs and output dimensions are the same, we can just check
// that the basis for the single remaining dimension is the identity.
sl = sl.flattenIns().flattenOuts();
int b = 0;
const auto &inDimBases = sl.bases.begin()->second;
for (auto basis : inDimBases) {
if (basis[0] != (1 << b)) {
return false;
}
b++;
}
return true;
}

SmallVector<std::pair<StringAttr, int32_t>>
LinearLayout::apply(ArrayRef<std::pair<StringAttr, int32_t>> ins) const {
assertDimsEqualIgnoringOrder(llvm::make_first_range(ins), getInDimNames());
70 changes: 70 additions & 0 deletions test/Conversion/reduce_to_llvm.mlir
@@ -0,0 +1,70 @@
// RUN: triton-opt %s --allocate-shared-memory --convert-triton-gpu-to-llvm --convert-nv-gpu-to-llvm | mlir-translate -mlir-to-llvmir | opt -S -O1 | FileCheck %s

#linear = #ttg.linear<{register = [[0, 2], [2, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}>

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {

// CHECK-LABEL: @reduce_linear_layout
tt.func private @reduce_linear_layout(%arg0: tensor<32x16xi32, #linear>) -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> {
// CHECK-NEXT: [[SRC0:%.*]] = extractvalue {{.*}} %0, 0
// CHECK-NEXT: [[SRC1:%.*]] = extractvalue {{.*}} %0, 1
// CHECK-NEXT: [[SRC2:%.*]] = extractvalue {{.*}} %0, 2
// CHECK-NEXT: [[SRC3:%.*]] = extractvalue {{.*}} %0, 3

// The layout looks like
// [[ T0:0, T32:0, T0:1, T32:1, ...
// [ T4:0, T36:0, T4:1, T36:1, ...
// [ T0:2, T32:2, T0:3, T32:3, ...
// [ T4:2, T36:2, T4:3, T36:3,
// ...
//
// A reduction along axis=0 consists of adding registers (0, 2) and (1, 3)
// before shuffling.
//
// Columns along axis=0 are contained within a warp, so reduction across warps
// is not needed.

// Reduce within threads
// CHECK-NEXT: [[SUM0:%.*]] = add i32 [[SRC0]], [[SRC2]]
// CHECK-NEXT: [[SUM1:%.*]] = add i32 [[SRC1]], [[SRC3]]

// Reduce within warp.
// CHECK-NEXT: [[W0:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[SUM0]], i32 16, i32 31)
// CHECK-NEXT: [[WSUM0:%.*]] = add i32 [[W0]], [[SUM0]]
// CHECK-NEXT: [[W1:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM0]], i32 8, i32 31)
// CHECK-NEXT: [[WSUM1:%.*]] = add i32 [[WSUM0]], [[W1]]
// CHECK-NEXT: [[W2:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM1]], i32 4, i32 31)
// CHECK-NEXT: [[WSUM2:%.*]] = add i32 [[WSUM1]], [[W2]]
// CHECK-NEXT: [[W3:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM2]], i32 2, i32 31)
// CHECK-NEXT: [[WSUM3:%.*]] = add i32 [[WSUM2]], [[W3]]

// CHECK-NEXT: [[W4:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[SUM1]], i32 16, i32 31)
// CHECK-NEXT: [[WSUM4:%.*]] = add i32 [[W4]], [[SUM1]]
// CHECK-NEXT: [[W5:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM4]], i32 8, i32 31)
// CHECK-NEXT: [[WSUM5:%.*]] = add i32 [[WSUM4]], [[W5]]
// CHECK-NEXT: [[W6:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM5]], i32 4, i32 31)
// CHECK-NEXT: [[WSUM6:%.*]] = add i32 [[WSUM5]], [[W6]]
// CHECK-NEXT: [[W7:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM6]], i32 2, i32 31)
// CHECK-NEXT: [[WSUM7:%.*]] = add i32 [[WSUM6]], [[W7]]

// CHECK-NEXT: [[DST0:%.*]] = insertvalue { i32, i32 } undef, i32 [[WSUM3]], 0
// CHECK-NEXT: [[DST1:%.*]] = insertvalue { i32, i32 } [[DST0]], i32 [[WSUM7]], 1

%0 = "tt.reduce"(%arg0) ({
^bb0(%arg1: i32, %arg2: i32):
%1 = arith.addi %arg1, %arg2 : i32
tt.reduce.return %1 : i32
}) {axis = 0 : i32} : (tensor<32x16xi32, #linear>) -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>>

// CHECK-NEXT: ret { i32, i32 } [[DST1]]
tt.return %0 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>>
}

tt.func @anchor(%ptr: !llvm.ptr, %arg0: tensor<32x16xi32, #linear>) {
%0 = tt.call @reduce_linear_layout(%arg0) : (tensor<32x16xi32, #linear>) -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>>
%1 = builtin.unrealized_conversion_cast %0 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> to !llvm.struct<(i32, i32)>
llvm.store volatile %1, %ptr : !llvm.struct<(i32, i32)>, !llvm.ptr
tt.return
}

}
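
For intuition, a host-side C++ simulation (an illustrative sketch, not Triton code) of the xor-butterfly that the shfl.sync.bfly sequence above performs across a 32-lane warp. Offset 1 is skipped because lane bit 0 of this layout indexes the kept axis (axis = 1) rather than the reduced one:

#include <array>
#include <cstdio>

int main() {
  // One value per lane; the real kernel first folds the per-thread registers.
  std::array<int, 32> lanes;
  for (int i = 0; i < 32; ++i) lanes[i] = i;
  // bfly mode: each lane reads from lane (laneid XOR offset), then accumulates.
  for (int offset : {16, 8, 4, 2}) {
    std::array<int, 32> prev = lanes;
    for (int i = 0; i < 32; ++i) lanes[i] = prev[i] + prev[i ^ offset];
  }
  // Lanes that differ only in the un-shuffled bit 0 hold independent sums.
  std::printf("lane 0 = %d, lane 1 = %d\n", lanes[0], lanes[1]);
  return 0;
}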
25 changes: 25 additions & 0 deletions test/Conversion/tritongpu_to_llvm.mlir
@@ -2092,3 +2092,28 @@ tt.func @upcast_mxfp(%arg0: tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = #m
}

}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 1, 16], threadsPerWarp = [4, 4, 2], warpsPerCTA = [8, 1, 1], order = [2, 1, 0]}>
#linear = #ttg.linear<{register = [[0, 0], [0, 0], [0, 0], [0, 0]], lane = [[0, 0], [0, 1], [0, 2], [1, 0], [2, 0]], warp = [[4, 0], [8, 0], [16, 0]], block = []}>

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {

// CHECK-LABEL: expand_dims_linear_layout
tt.func private @expand_dims_linear_layout() -> tensor<1x4xi32, #linear> {
%0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 0, parent = #linear}>>
%1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 0, parent = #linear}>> -> tensor<1x4xi32, #linear>
// CHECK: return %{{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
tt.return %1 : tensor<1x4xi32, #linear>
}

// CHECK-LABEL: reshape_linear_layout_broadcasting
tt.func private @reshape_linear_layout_broadcasting(%arg0: tensor<32x4xbf16, #linear>) -> tensor<32x4x1xbf16, #blocked> {
// CHECK-COUNT-16: extractvalue
// CHECK-COUNT-16: insertvalue
%0 = tt.reshape %arg0 : tensor<32x4xbf16, #linear> -> tensor<32x4x1xbf16, #blocked>
tt.return %0 : tensor<32x4x1xbf16, #blocked>
}

}
23 changes: 23 additions & 0 deletions test/TritonGPU/combine.mlir
@@ -2829,3 +2829,26 @@ tt.func @remat_across_regions(%arg0: i1, %arg1: tensor<8x8xf32, #blocked>) {
}

}

// -----

#linear = #ttg.linear<{register = [[1, 0], [0, 8], [0, 16]], lane = [[2, 0], [4, 0], [8, 0], [16, 0], [0, 1]], warp = [[0, 2], [0, 4]], block = []}>
#blocked = #ttg.blocked<{sizePerThread = [2, 4], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [1, 0]}>

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {

// CHECK-LABEL: reduce_linear_layouts
tt.func @reduce_linear_layouts(%arg0: tensor<32x32xi32, #linear>) -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #linear}>> {
// CHECK-NOT: convert_layout
%0 = ttg.convert_layout %arg0 : tensor<32x32xi32, #linear> -> tensor<32x32xi32, #blocked>
// CHECK-NEXT: tt.reduce
%1 = "tt.reduce" (%0) ({
^bb0(%arg1: i32, %arg2: i32):
tt.reduce.return %arg1 : i32
// CHECK: (tensor<32x32xi32, #linear>) -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #linear}>
}) {axis = 1 : i32} : (tensor<32x32xi32, #blocked>) -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
%2 = ttg.convert_layout %1 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #linear}>>
tt.return %2 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #linear}>>
}

}
15 changes: 15 additions & 0 deletions test/TritonGPU/pipeline-schedule-loop.mlir
@@ -334,4 +334,19 @@ tt.func @indirect_load(%lb : index, %ub : index, %step : index,
}
tt.return %loop#3: tensor<128x128xf32, #C>
}

// Verify that we don't schedule/pipeline loops with gpu.barrier
// CHECK-LABEL: @gpu_barrier
tt.func @gpu_barrier(%lb : index, %ub : index, %step : index,
%a_ptr_init : tensor<128x32x!tt.ptr<f16>, #A>) -> tensor<128x32xf16, #A> {
%init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A>
%loop = scf.for %iv = %lb to %ub step %step iter_args(%acc = %init) -> (tensor<128x32xf16, #A>) {
// CHECK-NOT: loop.cluster
%a = tt.load %a_ptr_init {tt.latency = 2 : i32} : tensor<128x32x!tt.ptr<f16>, #A>
%res = arith.addf %acc, %a : tensor<128x32xf16, #A>
gpu.barrier
scf.yield %res : tensor<128x32xf16, #A>
}
tt.return %loop#0 : tensor<128x32xf16, #A>
}
}
2 changes: 1 addition & 1 deletion unittest/Tools/CMakeLists.txt
@@ -1,5 +1,5 @@
add_triton_ut(
NAME LinearLayout
SRCS LinearLayoutTest.cpp
SRCS LayoutUtilsTest.cpp LinearLayoutTest.cpp
LIBS TritonTools
)