-
Notifications
You must be signed in to change notification settings - Fork 54
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
The `TritonTilingExt` dialect leverages linalg's `TilingInterface` to add tiling & fusion support for operators that cannot be represented in linalg Much of the tiling implementation is borrowed from linalg's tiling code. As an example, I have added a barebone version of `cumsum` that represents the cumulative sum of the inner-most dimension of a tensor. Other backends can then lower the op to lower-level implementation as needed after applying tiling & fusion.
- Loading branch information
1 parent
6fa7ce3
commit 230c38b
Showing
19 changed files
with
1,282 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
add_subdirectory(Conversion) | ||
add_subdirectory(Conversion) | ||
add_subdirectory(Dialect) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
add_subdirectory(TritonTilingExt) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
add_subdirectory(IR) |
11 changes: 11 additions & 0 deletions
11
include/triton-shared/Dialect/TritonTilingExt/IR/CMakeLists.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
set(LLVM_TARGET_DEFINITIONS TritonTilingExtOps.td) | ||
mlir_tablegen(TritonTilingExtOpsDialect.h.inc -gen-dialect-decls -dialect=ttx) | ||
mlir_tablegen(TritonTilingExtOpsDialect.cpp.inc -gen-dialect-defs -dialect=ttx) | ||
mlir_tablegen(TritonTilingExtOps.h.inc -gen-op-decls) | ||
mlir_tablegen(TritonTilingExtOps.cpp.inc -gen-op-defs) | ||
add_public_tablegen_target(TritonTilingExtOpsIncGen) | ||
|
||
set(LLVM_TARGET_DEFINITIONS TritonTilingExtInterfaces.td) | ||
mlir_tablegen(TritonTilingExtInterfaces.h.inc -gen-op-interface-decls) | ||
mlir_tablegen(TritonTilingExtInterfaces.cpp.inc -gen-op-interface-defs) | ||
add_public_tablegen_target(TritonTilingExtInterfacesIncGen) |
107 changes: 107 additions & 0 deletions
107
include/triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// Copyright (c) Microsoft Corporation. | ||
// Licensed under the MIT license. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_DIALECT_TRITON_TILING_EXT_IR_TRITON_TILING_EXT_DIALECT_H_ | ||
#define MLIR_DIALECT_TRITON_TILING_EXT_IR_TRITON_TILING_EXT_DIALECT_H_ | ||
|
||
#include "mlir/Dialect/Linalg/IR/Linalg.h" | ||
#include "mlir/IR/BuiltinOps.h" | ||
#include "mlir/IR/BuiltinTypes.h" | ||
#include "mlir/IR/Dialect.h" | ||
#include "mlir/IR/MLIRContext.h" | ||
#include "mlir/IR/OpDefinition.h" | ||
#include "mlir/IR/SymbolTable.h" | ||
#include "mlir/IR/TypeSupport.h" | ||
#include "mlir/IR/Types.h" | ||
#include "mlir/Interfaces/DestinationStyleOpInterface.h" | ||
#include "mlir/Interfaces/SideEffectInterfaces.h" | ||
#include "mlir/Interfaces/TilingInterface.h" | ||
|
||
//===----------------------------------------------------------------------===// | ||
// TritonTilingExt Operations | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtOpsDialect.h.inc" | ||
|
||
// Include the generated interface declarations. | ||
#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtInterfaces.h.inc" | ||
|
||
// Include the auto-generated header file containing the declarations of the | ||
// TritonTilingExt operations. | ||
#define GET_OP_CLASSES | ||
#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtOps.h.inc" | ||
|
||
namespace mlir { | ||
|
||
namespace ttx { | ||
|
||
// ----------------------------------------------------------------------------- | ||
// BufferizableOpInterface | ||
// ----------------------------------------------------------------------------- | ||
// All TritonTilingExtOps need to support bufferization: the process of | ||
// allocating buffers for tensors, thereby converting inputs and outputs of | ||
// tensor type to memref. This process is done by implementing the | ||
// "BufferizableOpInterface". We implement the interface for TritonTilingExtOps | ||
// through an external model instead of directly in TritonTilingExtOps.td to be | ||
// consistent with other ops in the mlir project. See some examples here: | ||
// - mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp | ||
// - mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp | ||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry); | ||
|
||
// ----------------------------------------------------------------------------- | ||
// TilingInterface | ||
// ----------------------------------------------------------------------------- | ||
// The three methods `getTiledImplementation`, `getResultTilePosition`, and | ||
// `generateResultTileValue` are implemented as part of the TilingInterface. | ||
// (see TilingInterface.td). These three methods are re-used across | ||
// all TritonTilingExtOps, while others method are implemented individually by | ||
// each operator depending on their use cases. | ||
template <typename TritonTilingExtOpTy> | ||
FailureOr<TilingResult> getTiledImplementation(TritonTilingExtOpTy op, | ||
OpBuilder &b, | ||
ArrayRef<OpFoldResult> offsets, | ||
ArrayRef<OpFoldResult> sizes); | ||
|
||
template <typename TritonTilingExtOpTy> | ||
LogicalResult getResultTilePosition(TritonTilingExtOpTy op, OpBuilder &b, | ||
unsigned resultNumber, | ||
ArrayRef<OpFoldResult> offsets, | ||
ArrayRef<OpFoldResult> sizes, | ||
SmallVector<OpFoldResult> &resultOffsets, | ||
SmallVector<OpFoldResult> &resultSizes); | ||
|
||
template <typename TritonTilingExtOpTy> | ||
FailureOr<TilingResult> | ||
generateResultTileValue(TritonTilingExtOpTy op, OpBuilder &b, | ||
unsigned resultNumber, ArrayRef<OpFoldResult> offsets, | ||
ArrayRef<OpFoldResult> sizes); | ||
|
||
// ----------------------------------------------------------------------------- | ||
// MemoryEffectsOpInterface | ||
// ----------------------------------------------------------------------------- | ||
// Implementation of the MemoryEffectsOpInterface for TritonTilingExtOps. | ||
// This allows DCE pass to determine if a TritonTilingExtOp is safe to be | ||
// removed. see TritonTilingExtOps.td for more details. | ||
template <typename TritonTilingExtOpTy> | ||
void getEffects( | ||
TritonTilingExtOpTy op, | ||
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> | ||
&effects); | ||
|
||
// ----------------------------------------------------------------------------- | ||
// Utilities | ||
// ----------------------------------------------------------------------------- | ||
// Utility method to extract a slice from the input source using either | ||
// tensor::ExtractSlice or memref::SubView | ||
Value getSlice(OpBuilder &b, Location loc, Value source, | ||
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, | ||
ArrayRef<OpFoldResult> strides); | ||
|
||
} // namespace ttx | ||
} // namespace mlir | ||
|
||
#endif // MLIR_DIALECT_TRITON_TILING_EXT_IR_TRITON_TILING_EXT_DIALECT_H_ |
102 changes: 102 additions & 0 deletions
102
include/triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtInterfaces.td
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// Copyright (c) Microsoft Corporation. | ||
// Licensed under the MIT license. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_TRITON_TILING_EXT_DIALECT_INTERFACES | ||
#define MLIR_TRITON_TILING_EXT_DIALECT_INTERFACES | ||
|
||
include "mlir/IR/OpBase.td" | ||
|
||
// | ||
// Linalg operators require providing affine maps that define how input / output | ||
// buffers are accessed together with a region that defines how each output | ||
// element is computed; this requirement doesn't work well for operations such as | ||
// `scan`. | ||
// | ||
// Fortunately, the introduction of the TilingInterface allows us to add tiling | ||
// and fusion support to operations that don't fit into the linalg dialect. | ||
// This fits our purpose perfectly: our `scan` operators can be treated as an | ||
// "opaque" / "completely abstract" operation that can be tiled on the batch | ||
// dimensions -- we don't need to provide any associated body together with it. | ||
// | ||
// However, this doesn't mean that we entirely forgo the "indexing map" concept. | ||
// For example, consider the following: | ||
// | ||
// - ttx.scan ins(%1 : tensor<128x768xbf16>) | ||
// outs(%2 : tensor<128x768xbf16>) -> tensor<128x768xbf16> | ||
// | ||
// Tiling the batch dimension gives us: | ||
// | ||
// for (i = 0 to 128) { | ||
// %sliceIn = extract slice from input: tensor<1x768xbf16> | ||
// %sliceOut = extract slice from output: tensor<1x768xbf16> | ||
// %res = ttx.scan ins(slice : tensor<1x768xbf16>) | ||
// outs(%2 : tensor<1x768xbf16>) -> tensor<1x768xbf16> | ||
// insert %res into output | ||
// } | ||
// | ||
// Now our `scan` op has the semantic of running `scan` on a rank-1 tensor and | ||
// can be lowered further to other hardware-specific ops or external library | ||
// calls. | ||
// | ||
// This tiling pattern is essentially the same as tiling a linalg.generic op | ||
// with an identity map. The only difference is we don't need a body associated | ||
// with our `scan` op. | ||
// | ||
// With this idea in mind, the TritonTilingExtInterface exposes methods | ||
// that will be implemented individually by each TritonTilingExtOp, providing | ||
// the indexing map for each input / output that can then be used to generate | ||
// the correct slices during tiling and fusion. | ||
// | ||
// There might be other ops in the future that won't fit in this "indexing map" | ||
// approach; we will consider making TritonTilingExtInterface an optional | ||
// interface for such ops. | ||
// | ||
|
||
def TritonTilingExtInterface : OpInterface<"TritonTilingExtInterface"> { | ||
let cppNamespace = "::mlir::ttx"; | ||
let methods = [ | ||
InterfaceMethod< | ||
/*desc=*/[{ | ||
Return the indexing map for the input operand with the given `index`. | ||
The `tileSizes` input indicates the requested tile size during tiling | ||
in case the indexing map for the operator is dependent on it. | ||
}], | ||
/*retTy=*/"AffineMap", | ||
/*methodName=*/"getInputIndexingMap", | ||
/*args=*/(ins "MLIRContext*":$context, | ||
"unsigned int":$index, | ||
"ArrayRef<OpFoldResult>":$tileSizes) | ||
>, | ||
InterfaceMethod< | ||
/*desc=*/[{ | ||
Return the indexing map for the output operand with the given `index`. | ||
The `tileSizes` input indicates the requested tile size during tiling | ||
in case the indexing map for the operator is dependent on it. | ||
}], | ||
/*retTy=*/"AffineMap", | ||
/*methodName=*/"getOutputIndexingMap", | ||
/*args=*/(ins "MLIRContext*":$context, | ||
"unsigned int":$index, | ||
"ArrayRef<OpFoldResult>":$tileSizes) | ||
>, | ||
InterfaceMethod< | ||
/*desc=*/[{ | ||
Return the indexing map for the operand with the given `index`. | ||
This method returns the operand in order of inputs followed by outputs. | ||
The `tileSizes` input indicates the requested tile size during tiling | ||
in case the indexing map for the operator is dependent on it. | ||
}], | ||
/*retTy=*/"AffineMap", | ||
/*methodName=*/"getIndexingMap", | ||
/*args=*/(ins "MLIRContext*":$context, | ||
"unsigned int":$index, | ||
"ArrayRef<OpFoldResult>":$tileSizes) | ||
> | ||
]; | ||
} | ||
|
||
#endif |
Oops, something went wrong.