Skip to content

Commit

Permalink
Add TritonTilingExt dialect (#45)
Browse files Browse the repository at this point in the history
The `TritonTilingExt` dialect leverages linalg's `TilingInterface` to add tiling & fusion support for operators that cannot be represented in linalg. Much of the tiling implementation is borrowed from linalg's tiling code. As an example, I have added a bare-bones version of `cumsum` that represents the cumulative sum of the inner-most dimension of a tensor. Other backends can then lower the op to a lower-level implementation as needed after applying tiling & fusion.
  • Loading branch information
nhat-nguyen authored Nov 8, 2023
1 parent 6fa7ce3 commit 230c38b
Show file tree
Hide file tree
Showing 19 changed files with 1,282 additions and 13 deletions.
3 changes: 2 additions & 1 deletion include/triton-shared/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
# Sub-projects built under include/triton-shared.
# NOTE(review): `Conversion` appears twice below; this looks like the removed
# and re-added line of a diff view rather than two intentional calls - confirm
# against the actual file before relying on it.
add_subdirectory(Conversion)
add_subdirectory(Conversion)
add_subdirectory(Dialect)
1 change: 1 addition & 0 deletions include/triton-shared/Dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Descend into the TritonTilingExt dialect directory.
add_subdirectory(TritonTilingExt)
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Descend into the IR directory.
add_subdirectory(IR)
11 changes: 11 additions & 0 deletions include/triton-shared/Dialect/TritonTilingExt/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# TableGen rules for the TritonTilingExt (`ttx`) dialect.
#
# LLVM_TARGET_DEFINITIONS selects the .td file consumed by the mlir_tablegen()
# calls that follow it, so each group below starts by setting it.

# --- Dialect and operations, generated from TritonTilingExtOps.td -------------
set(LLVM_TARGET_DEFINITIONS TritonTilingExtOps.td)
# Dialect class declarations / definitions, restricted to the `ttx` dialect.
mlir_tablegen(TritonTilingExtOpsDialect.h.inc -gen-dialect-decls -dialect=ttx)
mlir_tablegen(TritonTilingExtOpsDialect.cpp.inc -gen-dialect-defs -dialect=ttx)
# Operation class declarations / definitions.
mlir_tablegen(TritonTilingExtOps.h.inc -gen-op-decls)
mlir_tablegen(TritonTilingExtOps.cpp.inc -gen-op-defs)
# Target covering the four .inc files above; presumably what libraries that
# include those files add as a dependency - confirm in the lib/ CMake files.
add_public_tablegen_target(TritonTilingExtOpsIncGen)

# --- Op interfaces, generated from TritonTilingExtInterfaces.td ---------------
set(LLVM_TARGET_DEFINITIONS TritonTilingExtInterfaces.td)
# Op-interface declarations / definitions.
mlir_tablegen(TritonTilingExtInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(TritonTilingExtInterfaces.cpp.inc -gen-op-interface-defs)
# Separate target for the two interface .inc files.
add_public_tablegen_target(TritonTilingExtInterfacesIncGen)
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
//===----------------------------------------------------------------------===//
//
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_TRITON_TILING_EXT_IR_TRITON_TILING_EXT_DIALECT_H_
#define MLIR_DIALECT_TRITON_TILING_EXT_IR_TRITON_TILING_EXT_DIALECT_H_

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/TilingInterface.h"

//===----------------------------------------------------------------------===//
// TritonTilingExt Operations
//===----------------------------------------------------------------------===//

#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtOpsDialect.h.inc"

// Include the generated interface declarations.
#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtInterfaces.h.inc"

// Include the auto-generated header file containing the declarations of the
// TritonTilingExt operations.
#define GET_OP_CLASSES
#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtOps.h.inc"

namespace mlir {

namespace ttx {

// -----------------------------------------------------------------------------
// BufferizableOpInterface
// -----------------------------------------------------------------------------
// All TritonTilingExtOps need to support bufferization: the process of
// allocating buffers for tensors, thereby converting inputs and outputs of
// tensor type to memref. This process is done by implementing the
// "BufferizableOpInterface". We implement the interface for TritonTilingExtOps
// through an external model instead of directly in TritonTilingExtOps.td to be
// consistent with other ops in the mlir project. See some examples here:
// - mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
// - mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
//
// Entry point that hooks the external models described above into `registry`.
// Only the declaration lives in this header.
// NOTE(review): presumably this has to be called on the registry used to build
// the MLIRContext before any ttx op is bufferized - confirm with the callers.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);

// -----------------------------------------------------------------------------
// TilingInterface
// -----------------------------------------------------------------------------
// The three methods `getTiledImplementation`, `getResultTilePosition`, and
// `generateResultTileValue` are implemented as part of the TilingInterface
// (see TilingInterface.td). These three methods are re-used across
// all TritonTilingExtOps, while other methods are implemented individually by
// each operator depending on their use cases.
//
// Shared implementation of TilingInterface's `getTiledImplementation`,
// parameterized on the concrete op type `TritonTilingExtOpTy`.
//   op      - the op being tiled.
//   b       - builder used to create any new IR.
//   offsets - requested tile offsets (see TilingInterface.td for the exact
//             contract).
//   sizes   - requested tile sizes (same note as `offsets`).
// Returns a `TilingResult` on success, or failure.
// NOTE(review): this is a declaration only; the template definition (or its
// explicit instantiations) is expected to live elsewhere - confirm.
template <typename TritonTilingExtOpTy>
FailureOr<TilingResult> getTiledImplementation(TritonTilingExtOpTy op,
OpBuilder &b,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes);

// Shared implementation of TilingInterface's `getResultTilePosition`,
// parameterized on the concrete op type `TritonTilingExtOpTy`.
//   op            - the op being tiled.
//   b             - builder used to create any new IR.
//   resultNumber  - index of the op result being queried.
//   offsets/sizes - requested tile offsets and sizes (see TilingInterface.td
//                   for the exact contract).
//   resultOffsets - non-const reference; presumably filled with the offsets of
//                   the corresponding tile of the result - confirm.
//   resultSizes   - non-const reference; presumably filled with the sizes of
//                   the corresponding tile of the result - confirm.
// Returns success or failure as a `LogicalResult`.
template <typename TritonTilingExtOpTy>
LogicalResult getResultTilePosition(TritonTilingExtOpTy op, OpBuilder &b,
unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
SmallVector<OpFoldResult> &resultOffsets,
SmallVector<OpFoldResult> &resultSizes);

// Shared implementation of TilingInterface's `generateResultTileValue`,
// parameterized on the concrete op type `TritonTilingExtOpTy`.
//   op            - the op being tiled.
//   b             - builder used to create any new IR.
//   resultNumber  - index of the op result for which a tile is requested.
//   offsets/sizes - offsets and sizes of the requested tile.
//                   NOTE(review): presumably expressed in the coordinate space
//                   of that result rather than of the iteration domain -
//                   confirm against TilingInterface.td.
// Returns a `TilingResult` on success, or failure.
template <typename TritonTilingExtOpTy>
FailureOr<TilingResult>
generateResultTileValue(TritonTilingExtOpTy op, OpBuilder &b,
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes);

// -----------------------------------------------------------------------------
// MemoryEffectsOpInterface
// -----------------------------------------------------------------------------
// Implementation of the MemoryEffectsOpInterface for TritonTilingExtOps.
// This allows DCE pass to determine if a TritonTilingExtOp is safe to be
// removed. See TritonTilingExtOps.td for more details.
//
//   op      - the op whose memory effects are being described.
//   effects - non-const reference to the caller's list of effect instances;
//             presumably this function appends the op's effects to it -
//             confirm in the definition.
template <typename TritonTilingExtOpTy>
void getEffects(
TritonTilingExtOpTy op,
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects);

// -----------------------------------------------------------------------------
// Utilities
// -----------------------------------------------------------------------------
// Utility method to extract a slice from the input source using either
// tensor::ExtractSlice or memref::SubView
//
//   b       - builder used to create the slicing op.
//   loc     - location attached to the created op.
//   source  - value to slice.
//             NOTE(review): presumably the op kind is picked from whether
//             `source` is a tensor or a memref - confirm in the definition.
//   offsets - offsets of the slice within `source`.
//   sizes   - sizes of the slice.
//   strides - strides of the slice.
// Returns the value produced by the created slicing op.
Value getSlice(OpBuilder &b, Location loc, Value source,
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides);

} // namespace ttx
} // namespace mlir

#endif // MLIR_DIALECT_TRITON_TILING_EXT_IR_TRITON_TILING_EXT_DIALECT_H_
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
//===----------------------------------------------------------------------===//
//
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_TRITON_TILING_EXT_DIALECT_INTERFACES
#define MLIR_TRITON_TILING_EXT_DIALECT_INTERFACES

include "mlir/IR/OpBase.td"

//
// Linalg operators require providing affine maps that define how input / output
// buffers are accessed together with a region that defines how each output
// element is computed; this requirement doesn't work well for operations such as
// `scan`.
//
// Fortunately, the introduction of the TilingInterface allows us to add tiling
// and fusion support to operations that don't fit into the linalg dialect.
// This fits our purpose perfectly: our `scan` operators can be treated as an
// "opaque" / "completely abstract" operation that can be tiled on the batch
// dimensions -- we don't need to provide any associated body together with it.
//
// However, this doesn't mean that we entirely forgo the "indexing map" concept.
// For example, consider the following:
//
// - ttx.scan ins(%1 : tensor<128x768xbf16>)
// outs(%2 : tensor<128x768xbf16>) -> tensor<128x768xbf16>
//
// Tiling the batch dimension gives us:
//
// for (i = 0 to 128) {
// %sliceIn = extract slice from input: tensor<1x768xbf16>
// %sliceOut = extract slice from output: tensor<1x768xbf16>
// %res = ttx.scan ins(%sliceIn : tensor<1x768xbf16>)
// outs(%sliceOut : tensor<1x768xbf16>) -> tensor<1x768xbf16>
// insert %res into output
// }
//
// Now our `scan` op has the semantics of running `scan` on a rank-1 tensor and
// can be lowered further to other hardware-specific ops or external library
// calls.
//
// This tiling pattern is essentially the same as tiling a linalg.generic op
// with an identity map. The only difference is we don't need a body associated
// with our `scan` op.
//
// With this idea in mind, the TritonTilingExtInterface exposes methods
// that will be implemented individually by each TritonTilingExtOp, providing
// the indexing map for each input / output that can then be used to generate
// the correct slices during tiling and fusion.
//
// There might be other ops in the future that won't fit in this "indexing map"
// approach; we will consider making TritonTilingExtInterface an optional
// interface for such ops.
//

// Op interface through which a TritonTilingExtOp reports the indexing map
// (an `AffineMap`) associated with each of its operands.
//
// All three methods take the same arguments:
//   - `context`:   the MLIRContext passed in by the caller.
//   - `index`:     which operand to query (see each method for how operands
//                  are numbered).
//   - `tileSizes`: the requested tile sizes, for ops whose indexing map
//                  depends on them.
//
// No method body or default implementation is given in this definition, so
// each op implementing the interface supplies its own.
def TritonTilingExtInterface : OpInterface<"TritonTilingExtInterface"> {
// The generated C++ interface class is placed in this namespace.
let cppNamespace = "::mlir::ttx";
let methods = [
// Indexing map of an input operand; `index` counts input operands.
InterfaceMethod<
/*desc=*/[{
Return the indexing map for the input operand with the given `index`.
The `tileSizes` input indicates the requested tile size during tiling
in case the indexing map for the operator is dependent on it.
}],
/*retTy=*/"AffineMap",
/*methodName=*/"getInputIndexingMap",
/*args=*/(ins "MLIRContext*":$context,
"unsigned int":$index,
"ArrayRef<OpFoldResult>":$tileSizes)
>,
// Indexing map of an output operand; `index` counts output operands.
InterfaceMethod<
/*desc=*/[{
Return the indexing map for the output operand with the given `index`.
The `tileSizes` input indicates the requested tile size during tiling
in case the indexing map for the operator is dependent on it.
}],
/*retTy=*/"AffineMap",
/*methodName=*/"getOutputIndexingMap",
/*args=*/(ins "MLIRContext*":$context,
"unsigned int":$index,
"ArrayRef<OpFoldResult>":$tileSizes)
>,
// Indexing map of any operand.
// NOTE(review): the description below says the method "returns the operand",
// but the return type is AffineMap; presumably it means `index` numbers the
// operands with all inputs first, followed by the outputs - confirm with the
// implementing ops.
InterfaceMethod<
/*desc=*/[{
Return the indexing map for the operand with the given `index`.
This method returns the operand in order of inputs followed by outputs.
The `tileSizes` input indicates the requested tile size during tiling
in case the indexing map for the operator is dependent on it.
}],
/*retTy=*/"AffineMap",
/*methodName=*/"getIndexingMap",
/*args=*/(ins "MLIRContext*":$context,
"unsigned int":$index,
"ArrayRef<OpFoldResult>":$tileSizes)
>
];
}

#endif
Loading

0 comments on commit 230c38b

Please sign in to comment.