Implement IR support for LinearLayouts
We also exercise this in scale_dot, where we enable support for warps of
arbitrary shape (previously we only allowed `[num_warps, 1]`).

With this infra in place, it should be rather easy to move from the
legacy layouts to using LLs to represent all of our layouts.

Something I'm concerned about is the amount of recomputation that
happens in methods like `getSizePerThread`, which recompute their
result on every call. There might be an optimisation opportunity here:
cache the results of all these functions (a rough sketch of such a
cache is below).
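A sketch of what such a cache could look like, keying on the (uniqued,
immutable) encoding attribute. This is an illustration of the idea, not
existing Triton code; `LayoutQueryCache` and its method names are hypothetical.

```cpp
#include <mutex>

#include "mlir/IR/Attributes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"

// Hypothetical cache for derived layout queries. MLIR attributes are uniqued
// and immutable, so the attribute itself is a stable key and a cached answer
// can never go stale.
class LayoutQueryCache {
public:
  // Returns the cached result for `layout`, computing it on the first call.
  template <typename ComputeFn>
  llvm::SmallVector<unsigned> getSizePerThread(mlir::Attribute layout,
                                               ComputeFn compute) {
    std::lock_guard<std::mutex> lock(mutex);
    auto it = cache.find(layout);
    if (it == cache.end())
      it = cache.try_emplace(layout, compute(layout)).first;
    return it->second;
  }

private:
  std::mutex mutex;
  llvm::DenseMap<mlir::Attribute, llvm::SmallVector<unsigned>> cache;
};
```

The same pattern would apply to `getElemsPerThread` and the other derived
queries.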

We choose the IR representation of an LL via its canonical form + a
`repOrder` for several reasons:
- It's generally more compact
- It's easier to CSE, so it's easier to see when two layouts are in fact
  the same.
- A technical reason: the `toLinearLayout` function returns a layout
  with output dimensions `dim0, ..., dim<rank-1>`; in other words, it
  "forgets" the repetition order. Without the repetition order, we cannot
  recover the tile size of the argument. In particular, we cannot recover
  `getSizePerThread`. There is an argument to be made about whether
  `getSizePerThread` is useful on its own, or whether `getElemsPerThread`
  is the really useful abstraction here, but for now we keep both for
  backwards compatibility.
lezcano committed Nov 15, 2024
1 parent 7f06338 commit 2f7e8f7
Showing 12 changed files with 792 additions and 204 deletions.
58 changes: 58 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -151,6 +151,64 @@ triton::gpu::BlockedEncodingAttr
getDefaultBlockedEncoding(MLIRContext *context, ArrayRef<int64_t> shape,
int numWarps, int threadsPerWarp, int numCTAs);

// For each output dimension d, ensure that the layout's output size (i.e., its
// codomain) does not exceed shape[d]. Do this without changing the size of the
// layout's inputs (i.e., leave its domain unchanged).
//
// This function is invariant to the order of the layout's input and output
// dimensions.
//
// We achieve this by setting the largest value in each output dimension d to 0
// because bases that map to a location larger than shape[d]
// effectively duplicate along that dimension. For example, consider a layout
// with an output dimension size of 32, and we call ensureLayoutNotLargerThan to
// shrink the output dimension size to 8:
//
// L(register=1) = 8
// L(register=2) = 4
// L(register=4) = 1
// L(lane=1) = 2
// L(lane=2) = 16
//
// In the first step, we shrink the output dimension size to 16 by setting
// L(lane=2) to 0:
//
// L(register=1) = 8
// L(register=2) = 4
// L(register=4) = 1
// L(lane=1) = 2
// L(lane=2) = 0
//
// This means that lane=2 has the same data as lane=0.
//
// Now the output dimension of this layout has a size of 16, which is still
// larger than 8. We find the current largest value in the output dimension,
// which is L(register=1) = 8, and we set L(register=1) to 0:
//
// L(register=1) = 0
// L(register=2) = 4
// L(register=4) = 1
// L(lane=1) = 2
// L(lane=2) = 0
//
// Now the output dimension of this layout has a size of 8, which is the desired
// size. Note that this method works only because the bases are powers of two.
// It is unclear what to do when they are not.
LinearLayout ensureLayoutNotLargerThan(
const LinearLayout &layout,
const llvm::SmallDenseMap<StringAttr, int64_t> &shape);

// For each out-dim d, ensure the layout's out-size (i.e. its codomain) is no
// smaller than shape[d]. Do this by increasing the size of the layout's inputs
// along its most-minor dimension ("register" for register layouts, "offset" for
// shared layouts).
//
// This function is invariant to the order of the layout's input dimensions, but
// it cares about the order of the output dims, which should be minor-to-major.
LinearLayout ensureLayoutNotSmallerThan(
const LinearLayout &layout,
const llvm::SmallDenseMap<StringAttr, int64_t> &shape);

// Dump information about which threads/registers contain each of the tensor
// elements.
void dumpLayout(RankedTensorType tensorType);
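The shrinking loop described in the `ensureLayoutNotLargerThan` comment above
can be illustrated with a small standalone sketch. This is a toy model that
only tracks the basis values of a single output dimension; the real function
operates on `LinearLayout` bases across all input and output dimensions.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// `bases` holds, for one output dimension, the value each input basis vector
// maps to (e.g. {8, 4, 1, 2, 16} for the registers/lanes in the example
// above). `outSize` is the current codomain size of that dimension.
void shrinkOutDim(std::vector<int32_t> &bases, int32_t &outSize,
                  int32_t targetSize) {
  assert(targetSize > 0 && (targetSize & (targetSize - 1)) == 0 &&
         "target must be a power of 2");
  while (outSize > targetSize) {
    // Zero out the largest surviving basis value: every input offset that
    // used to reach the upper half of the dimension now aliases the lower
    // half, which halves the effective output size.
    auto it = std::max_element(bases.begin(), bases.end());
    assert(it != bases.end() && *it > 0 && "cannot shrink further");
    *it = 0;
    outSize /= 2;
  }
}
```

Running it on the example from the comment (`bases = {8, 4, 1, 2, 16}`,
`outSize = 32`, `targetSize = 8`) first zeroes the 16 (`lane=2`) and then the
8 (`register=1`), leaving `{0, 4, 1, 2, 0}` with an output size of 8.
`ensureLayoutNotSmallerThan` is the dual operation: instead of zeroing bases,
it grows the layout's most-minor input dimension until the codomain covers
`shape[d]`.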
35 changes: 33 additions & 2 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -56,7 +56,6 @@ Right now, Triton implements two main classes of layouts: shared, and distribute
code extraBaseClassDeclaration = [{
unsigned getTotalElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
SmallVector<unsigned> getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
::mlir::LogicalResult verifyLayoutForArg(::mlir::Operation* op, unsigned argNo) const;
}];
}

@@ -147,7 +146,6 @@ addition, if there's only one CTA per CGA, then Triton canonicalizes CTAOrder to
let genVerifyDecl = 1;
let skipDefaultBuilders = 1;
}

//===----------------------------------------------------------------------===//
// Shared Layout Encoding
//===----------------------------------------------------------------------===//
@@ -571,6 +569,39 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
}];
}

//===----------------------------------------------------------------------===//
// Linear Layout Encoding
//===----------------------------------------------------------------------===//

def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"> {
let mnemonic = "linear";

let description = [{
See the docs in LinearLayout.h for the definition of linear layouts.
}];

let parameters = (
ins
"LinearLayout":$linearLayout,
ArrayRefParameter<"unsigned">:$repOrder__
);

let builders = [
AttrBuilder<(ins "const LinearLayout&":$linearLayout, "ArrayRef<unsigned>":$repOrder)>
];


let extraClassDeclaration = extraDistributedDeclaration # [{
SmallVector<unsigned> getContigPerThread() const;
SmallVector<unsigned> getOrder() const;
}];

let genVerifyDecl = 1;
let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;
}
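A sketch of what constructing the new encoding could look like from C++. The
`get` signature is assumed from the `AttrBuilder` above, and
`LinearLayout::identity1D` is used only for illustration; the attribute's
verifier may impose additional constraints on the input dimensions, so treat
this as a sketch rather than a known-good construction.

```cpp
#include "mlir/IR/BuiltinAttributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Tools/LinearLayout.h"

// Sketch: a 32x4 tensor distributed over 32 lanes (dim0) and 4 warps (dim1),
// with one register per thread and a single block, wrapped in the new
// encoding. The exact get() signature comes from the builder declared above.
mlir::Attribute makeExampleLinearEncoding(mlir::MLIRContext *ctx) {
  auto S = [&](llvm::StringRef s) { return mlir::StringAttr::get(ctx, s); };
  mlir::triton::LinearLayout ll =
      mlir::triton::LinearLayout::identity1D(1, S("register"), S("dim0")) *
      mlir::triton::LinearLayout::identity1D(32, S("lane"), S("dim0")) *
      mlir::triton::LinearLayout::identity1D(4, S("warp"), S("dim1")) *
      mlir::triton::LinearLayout::identity1D(1, S("block"), S("dim0"));
  return mlir::triton::gpu::LinearEncodingAttr::get(ctx, ll,
                                                    /*repOrder=*/{0, 1});
}
```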


//===----------------------------------------------------------------------===//
// Blocked Layout Encoding
//===----------------------------------------------------------------------===//
3 changes: 3 additions & 0 deletions include/triton/Tools/LinearLayout.h
@@ -9,6 +9,7 @@
#include <vector>

#include "mlir/IR/BuiltinAttributes.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
@@ -432,6 +433,7 @@ class LinearLayout {
// (e.g. by reshaping) then the order doesn't really affect anything.
auto getInDimNames() const { return llvm::make_first_range(bases); }
auto getOutDimNames() const { return llvm::make_first_range(outDims); }
auto getOutDimSizes() const { return llvm::make_second_range(outDims); }

// Gets the position that this outDim occupies in getOutDimNames(). Asserts
// if the dim is not present.
@@ -693,6 +695,7 @@ class LinearLayout {
return !(lhs == rhs);
}
bool equalIgnoringOutDimSizes(const LinearLayout &other) const;
friend size_t hash_value(const LinearLayout &layout);

private:
// Factory function that gracefully fails rather than asserts if the layout is
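The new `hash_value` friend lets a `LinearLayout` participate in llvm-style
hashing, which the uniqued attribute storage of `LinearEncodingAttr` needs. A
hypothetical sketch of the kind of key hashing this enables; `hashEncodingKey`
is illustrative only, not the generated storage code.

```cpp
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Hashing.h"
#include "triton/Tools/LinearLayout.h"

// Hypothetical: fold a LinearLayout and a repetition order into one hash key,
// as an attribute storage (or any cache) could do now that
// hash_value(const LinearLayout &) is available.
llvm::hash_code hashEncodingKey(const mlir::triton::LinearLayout &layout,
                                llvm::ArrayRef<unsigned> repOrder) {
  return llvm::hash_combine(
      hash_value(layout), // found via ADL on mlir::triton::LinearLayout
      llvm::hash_combine_range(repOrder.begin(), repOrder.end()));
}
```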
3 changes: 3 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -397,6 +397,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
if (isa<BlockedEncodingAttr>(layout)) {
return true;
}
if (isa<LinearEncodingAttr>(layout)) {
return true;
}
if (auto slice = dyn_cast<SliceEncodingAttr>(layout)) {
return layoutIsOK(slice.getParent());
}
18 changes: 7 additions & 11 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -138,17 +138,13 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {

// FIXME [Dot LL]
// Do for all DotOperandEncodingAttr once we have LLs for all of them
static bool isSupportedDotOpLayout(RankedTensorType type) {
auto layout = type.getEncoding();
auto bitwidth = type.getElementType().getIntOrFloatBitWidth();
static bool isSupportedDotOpLayout(Attribute layout) {
if (auto dot = dyn_cast<DotOperandEncodingAttr>(layout)) {
auto kWidth = dot.getKWidth();
// Use when the SharedToDotOperandMMAv2OrV3 is known to be buggy:
// - kWidth == 8
// - kWidth == 4, bitwidth = 32
// - fp8 with kWidth == 4 and warpSize != {numWarps, 1} or {1, numWarps}
if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dot.getParent())) {
bool legacyLoweringIsBuggy =
kWidth >= 8 || (kWidth == 4 && bitwidth == 32);
bool legacyLoweringIsBuggy = dot.getKWidth() >= 4;
return legacyLoweringIsBuggy && mma.isAmpere();
}
if (isa<AMDMfmaEncodingAttr>(dot.getParent()))
@@ -165,9 +161,9 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
if (isa<SharedEncodingAttr>(srcLayout) &&
(isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
dstLayout) ||
isSupportedDotOpLayout(dstTy))) {
(isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr,
LinearEncodingAttr>(dstLayout) ||
isSupportedDotOpLayout(dstLayout))) {
return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
rewriter);
}
@@ -207,7 +203,7 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
auto dstShape = dstTy.getShape();
auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
auto dstLayout = dstTy.getEncoding();
assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstTy)) &&
assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstLayout)) &&
"Unexpected rank of ConvertLayout(shared->distributed)");

auto smemObj = LLVM::getSharedMemoryObjectFromStruct(