[LinalgExt] Add online_attention op (#17536)
This patch adds a new online_attention op. This op represents a
partially reduced attention op that can be tiled along its k2
reduction dimension. The op also carries indexing maps, supports tiling
on all dimensions other than the k1 dimension, and can be decomposed
based on any given indexing maps.

This patch also makes the CPU backend use online_attention to decompose
attention and to tile its reduction dimension with LLVMCPUTile, which in
turn allows attention to be tiled along the N and batch dimensions.
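For background, the partial reduction this op represents is the standard online-softmax recurrence (shown here as context; this formulation is not code from the patch). Processing the k2 dimension in tiles S_i = scale * (Q @ K_i.T), while carrying a running max m and running normalizer l:

```latex
\begin{aligned}
m_i            &= \max\bigl(m_{i-1},\ \operatorname{rowmax}(S_i)\bigr)\\
P_i            &= \exp(S_i - m_i)\\
l_i            &= e^{m_{i-1}-m_i}\, l_{i-1} + \operatorname{rowsum}(P_i)\\
\mathrm{acc}_i &= e^{m_{i-1}-m_i}\, \mathrm{acc}_{i-1} + P_i V_i
\end{aligned}
```

Each step depends only on the running (max, sum, acc) triple, which is why the k2 reduction can be tiled; after the last tile the accumulator is rescaled by 1/l, matching the combine step noted in the op description below.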
Groverkss authored Jun 12, 2024
1 parent 52b21f8 commit abf0087
Showing 25 changed files with 954 additions and 87 deletions.
42 changes: 22 additions & 20 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -890,7 +890,7 @@ getDefaultDistributedLevelTileSizes(Operation *op,
/// Splits the tile sizes in `parallelSizes` into `reductionSizes` for the
/// reduction loops.
static void splitParallelAndReductionTiles(
linalg::LinalgOp op, SmallVectorImpl<int64_t> &parallelSizes,
Operation *op, SmallVectorImpl<int64_t> &parallelSizes,
SmallVectorImpl<int64_t> &reductionSizes,
SmallVectorImpl<bool> *parallelScalableFlags = nullptr,
SmallVectorImpl<bool> *reductionScalableFlags = nullptr) {
@@ -900,8 +900,9 @@ static void splitParallelAndReductionTiles(
reductionScalableFlags->assign(parallelScalableFlags->begin(),
parallelScalableFlags->end());
}
TilingInterface tilingOp = cast<TilingInterface>(op);
for (auto [index, iteratorType] :
llvm::enumerate(op.getIteratorTypesArray())) {
llvm::enumerate(tilingOp.getLoopIteratorTypes())) {
if (iteratorType == utils::IteratorType::parallel) {
reductionSizes[index] = 0;
if (reductionScalableFlags)
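To make the helper's contract concrete, here is a minimal standalone sketch of the splitting behavior (toy code with hypothetical names, not from the patch); the real function additionally mirrors the split onto the scalable-tiling flags:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

enum class IteratorType { parallel, reduction };

// Toy splitParallelAndReductionTiles: `parallel` starts as the combined
// tile-size vector; after the split, each loop keeps its size in exactly
// one vector and gets 0 (meaning "do not tile at this level") in the other.
static void splitTiles(const std::vector<IteratorType> &iters,
                       std::vector<int64_t> &parallel,
                       std::vector<int64_t> &reduction) {
  reduction = parallel;
  for (size_t i = 0; i < iters.size(); ++i) {
    if (iters[i] == IteratorType::parallel)
      reduction[i] = 0; // parallel loop: untouched at the reduction level
    else
      parallel[i] = 0;  // reduction loop: untouched at the parallel level
  }
}

int main() {
  using IT = IteratorType;
  // Attention-like domain (batch, m, k1, k2, n); k1 and k2 are reductions.
  std::vector<IT> iters = {IT::parallel, IT::parallel, IT::reduction,
                           IT::reduction, IT::parallel};
  std::vector<int64_t> parallel = {20, 32, 0, 32, 32};
  std::vector<int64_t> reduction;
  splitTiles(iters, parallel, reduction);
  // parallel  -> [20, 32, 0, 0, 32]
  // reduction -> [0, 0, 0, 32, 0]  (cf. the attention lowering-config test)
  for (int64_t s : parallel) std::printf("%lld ", (long long)s);
  std::printf("\n");
  for (int64_t s : reduction) std::printf("%lld ", (long long)s);
  std::printf("\n");
  return 0;
}
```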
@@ -1121,9 +1122,9 @@ setMatmulRootConfig(mlir::FunctionOpInterface entryPointFn,
SmallVector<int64_t> parallelTileSizes = vecTileSizes;
SmallVector<int64_t> reductionTileSizes;
SmallVector<bool> reductionScalableFlags;
splitParallelAndReductionTiles(
cast<linalg::LinalgOp>(op.getOperation()), parallelTileSizes,
reductionTileSizes, &parallelScalableFlags, &reductionScalableFlags);
splitParallelAndReductionTiles(op, parallelTileSizes, reductionTileSizes,
&parallelScalableFlags,
&reductionScalableFlags);

if (vecPreProcStrategy == VectorPreProcStrategy::None) {
setVectorSizesForDynamicShapes(cast<linalg::LinalgOp>(op.getOperation()),
@@ -1751,14 +1752,13 @@ static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn,

// Batch, M and N (parallel dimensions) are distributed on workgroups.
DistributionHeuristicConfig config;
SmallVector<int64_t> distTileSizes = getDefaultDistributedLevelTileSizes(
attnOp, DistributionHeuristicConfig{});
SmallVector<int64_t> distTileSizes =
getDefaultDistributedLevelTileSizes(attnOp, config);

// Batch, M and N (parallel dimensions) are distributed on workgroups.
SmallVector<int64_t> vecTileSizes(attnOp.getIterationDomainRank(), 1);
// Mark reduction dimensions not to distribute.
for (int64_t i :
llvm::concat<const int64_t>(opInfo.getK1Dims(), opInfo.getK2Dims())) {
// Mark k1 reduction dimensions not to distribute.
for (int i : opInfo.getK1Dims()) {
vecTileSizes[i] = 0;
}
int64_t vectorSize = getVectorSize(entryPointFn, attnOp.getOutputType());
@@ -1773,18 +1773,17 @@ static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn,
/*numElem=*/tileSize, vectorSize, vectorSize);
}

// TODO (17467): Due to a bug in TileAndDecomposeAttention, N dimension
// cannot be tiled. Remove this once fixed.
for (int64_t i : opInfo.getNDims()) {
distTileSizes[i] = 0;
vecTileSizes[i] = 0;
}
SmallVector<int64_t> parallelTileSizes = vecTileSizes;
SmallVector<int64_t> reductionTileSizes;
splitParallelAndReductionTiles(attnOp, parallelTileSizes, reductionTileSizes);

TileSizesListType tileSizes = {distTileSizes, vecTileSizes};
LLVM_DEBUG(KD_DBGS() << "Vectorization/unrolling tile sizes (parallel): "
<< parallelTileSizes << "\n");
LLVM_DEBUG(KD_DBGS() << "Vectorization/unrolling tile sizes (reduction): "
<< reductionTileSizes << "\n");

// TODO: (Groverkss): Tile K2 here using reduction tiling interface once we
// have it. TileAndDecomposeAttention pass only tiles K2. I think it should
// be possible to tile K1 also, but need to explore it more.
TileSizesListType tileSizes = {distTileSizes, parallelTileSizes,
reductionTileSizes};

return setOpConfigAndEntryPointFnTranslation(
entryPointFn, attnOp, tileSizes,
@@ -1843,6 +1842,9 @@ setWinogradRootConfig(mlir::FunctionOpInterface entryPointFn,
tileSizes.push_back(distTileSizes);
SmallVector<int64_t> vecTileSizes(iterationRank, 1);
tileSizes.push_back(vecTileSizes);
// Dummy tiling config for reduction level.
SmallVector<int64_t> reductionTileSizes(iterationRank, 0);
tileSizes.push_back(reductionTileSizes);
return setOpConfigAndEntryPointFnTranslation(
entryPointFn, winogradOp, tileSizes,
DispatchLoweringPassPipeline::CPULinalgExtTileAndVectorize);
7 changes: 5 additions & 2 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -597,10 +597,13 @@ void addCPULinalgExtTileAndVectorizePipeline(
createLLVMCPUTilePass(tilingConfig.getVectorCommonParallelLevel()));
// TODO: Remove the pass once we have PartialReductionOpInterface implemented
// for AttentionOp.
funcPassManager.addPass(IREE::LinalgExt::createTileAttentionPass());
funcPassManager.addPass(IREE::LinalgExt::createDecomposeAttentionPass());
funcPassManager.addPass(
IREE::LinalgExt::createConvertAttentionToOnlineAttentionPass());
funcPassManager.addPass(
createLLVMCPUTilePass(tilingConfig.getVectorReductionLevel()));
funcPassManager.addPass(
IREE::LinalgExt::createDecomposeWinogradTransformPass());
funcPassManager.addPass(IREE::LinalgExt::createDecomposeAttentionPass());

{
GenericVectorizationPassOptions options;
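Net effect of the pipeline change above: attention is first rewritten into online_attention, the generic reduction-level createLLVMCPUTilePass then tiles its k2 dimension (rather than the attention-specific TileAttention pass), and only afterwards is the op decomposed — the tiling with LLVMCPUTile that the commit message describes.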
@@ -1531,7 +1531,7 @@ module {
return
}
}
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 1, 6, 64], [1, 1, 1, 1]]>
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 1, 6, 64], [1, 1, 1, 1], [0, 0, 0, 0]]>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPULinalgExtTileAndVectorize>
// CHECK: func.func @winograd_output_transform()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
@@ -1556,7 +1556,7 @@ module {
return
}
}
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 1, 6, 64], [1, 1, 1, 1]]>
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 1, 6, 64], [1, 1, 1, 1], [0, 0, 0, 0]]>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPULinalgExtTileAndVectorize>
// CHECK: func.func @winograd_input_transform()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
@@ -1581,7 +1581,7 @@ module {
return
}
}
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[8, 64], [1, 1]]>
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[8, 64], [1, 1], [0, 0]]>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPULinalgExtTileAndVectorize>
// CHECK: func.func @winograd_filter_transform()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
@@ -1613,7 +1613,7 @@ module {
return
}
}
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[20, 64, 0, 0, 0], [20, 32, 0, 0, 0]]>
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[20, 64, 0, 0, 64], [20, 32, 0, 0, 32], [0, 0, 0, 32, 0]]>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPULinalgExtTileAndVectorize>
// CHECK: func.func @attention()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
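Reading the updated attention config: assuming the iteration domain is ordered (batch, m, k1, k2, n) — consistent with which entries get zeroed — the three levels are distribution [20, 64, 0, 0, 64] (batch, M and N spread across workgroups), vector-parallel [20, 32, 0, 0, 32], and vector-reduction [0, 0, 0, 32, 0]: only k2 is tiled at the reduction level, and k1 is never tiled, exactly as the KernelDispatch change above prescribes.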
9 changes: 7 additions & 2 deletions compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel
@@ -29,6 +29,7 @@ iree_td_library(
"@llvm-project//mlir:ControlFlowInterfacesTdFiles",
"@llvm-project//mlir:DestinationStyleOpInterfaceTdFiles",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
"@llvm-project//mlir:LinalgOpsTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PDLDialectTdFiles",
"@llvm-project//mlir:SideEffectInterfacesTdFiles",
@@ -159,7 +160,9 @@ iree_gentbl_cc_library(
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "LinalgExtOps.td",
deps = [":td_files"],
deps = [
":td_files",
],
)

iree_gentbl_cc_library(
Expand Down Expand Up @@ -212,5 +215,7 @@ iree_tablegen_doc(
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "LinalgExtOps.td",
deps = [":td_files"],
deps = [
":td_files",
],
)
78 changes: 78 additions & 0 deletions compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -8,6 +8,7 @@

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
@@ -1315,6 +1316,9 @@ LogicalResult AttentionOp::verify() {
for (auto [i, dimExpr] : llvm::enumerate(indexingMap.getResults())) {
AffineDimExpr dim = cast<AffineDimExpr>(dimExpr);
int64_t pos = dim.getPosition();
if (ShapedType::isDynamic(valShape[i])) {
continue;
}
if (!foundDims[pos]) {
foundDims[pos] = true;
shape[pos] = valShape[i];
@@ -1427,6 +1431,79 @@ SmallVector<AffineMap> AttentionOp::getIndexingMapsArray() {
return results;
}

//===----------------------------------------------------------------------===//
// OnlineAttentionOp
//===----------------------------------------------------------------------===//

LogicalResult OnlineAttentionOp::verify() {
OnlineAttentionOp attnOp = *this;

SmallVector<AffineMap> indexingMaps = attnOp.getIndexingMapsArray();

// Check if indexing maps can represent attention.
FailureOr<AttentionOpDetail> maybeOpInfo =
AttentionOpDetail::get(indexingMaps);

// Check shape compatibility based on indexing maps.
SmallVector<int64_t> shape(getIterationDomainRank());
SmallVector<bool> foundDims(getIterationDomainRank(), false);
auto checkShape = [&shape, &foundDims,
&attnOp](StringRef operandName, ArrayRef<int64_t> valShape,
AffineMap indexingMap) -> LogicalResult {
if (indexingMap.getNumResults() != valShape.size()) {
return attnOp->emitError("Rank Mismatch for ")
<< operandName << ". Expected: " << indexingMap.getNumResults()
<< " Got: " << valShape.size();
}
for (auto [i, dimExpr] : llvm::enumerate(indexingMap.getResults())) {
AffineDimExpr dim = cast<AffineDimExpr>(dimExpr);
int64_t pos = dim.getPosition();
if (ShapedType::isDynamic(valShape[i])) {
continue;
}
if (!foundDims[pos]) {
foundDims[pos] = true;
shape[pos] = valShape[i];
}
if (shape[pos] != valShape[i]) {
return attnOp->emitError("Shape Mismatch for ")
<< operandName << ". Expected: " << shape[pos]
<< " Got: " << valShape[i];
}
}
return success();
};

if (failed(checkShape("Query", getQuery().getType().getShape(),
getQueryMap())) ||
failed(checkShape("Key", getKey().getType().getShape(), getKeyMap())) ||
failed(checkShape("Value", getValue().getType().getShape(),
getValueMap())) ||
failed(checkShape("Output", getOutput().getType().getShape(),
getOutputMap())) ||
failed(checkShape("Max", getMax().getType().getShape(), getMaxMap())) ||
failed(checkShape("Sum", getSum().getType().getShape(), getSumMap()))) {
return failure();
}

return success();
}

MutableOperandRange OnlineAttentionOp::getDpsInitsMutable() {
return MutableOperandRange(*this, /*numInputs=*/4, /*numInits=*/3);
}

LogicalResult OnlineAttentionOp::reifyResultShapes(
OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
return cast<LinalgExtOp>(getOperation())
.reifyResultShapes(b, reifiedReturnShapes);
}

SmallVector<AffineMap> OnlineAttentionOp::getIndexingMapsArray() {
return SmallVector<AffineMap>(
getIndexingMaps().getAsValueRange<AffineMapAttr>());
}

#define DEFINE_OP_GET_EFFECTS(OP_NAME) \
void OP_NAME::getEffects( \
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> \
@@ -1446,6 +1523,7 @@ DEFINE_OP_GET_EFFECTS(WinogradInputTransformOp)
DEFINE_OP_GET_EFFECTS(WinogradFilterTransformOp)
DEFINE_OP_GET_EFFECTS(WinogradOutputTransformOp)
DEFINE_OP_GET_EFFECTS(AttentionOp)
DEFINE_OP_GET_EFFECTS(OnlineAttentionOp)

} // namespace mlir::iree_compiler::IREE::LinalgExt

@@ -7,6 +7,7 @@
#ifndef IREE_COMPILER_DIALECT_LINALGEXT_IR_LINALGEXTOPS_H_
#define IREE_COMPILER_DIALECT_LINALGEXT_IR_LINALGEXTOPS_H_

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/Attributes.h"
91 changes: 91 additions & 0 deletions compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -9,6 +9,7 @@

include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtBase.td"
include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td"
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -678,6 +679,96 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_Op<"attention",
}];
}

//===----------------------------------------------------------------------===//
// OnlineAttention
//===----------------------------------------------------------------------===//

def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_PureOp<"online_attention",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DestinationStyleOpInterface, LinalgExtInterface,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
DeclareOpInterfaceMethods<TilingInterface,
["getIterationDomain",
"getLoopIteratorTypes",
"getResultTilePosition",
"getTiledImplementation"]>]> {
let summary = "Online Attention operator";
let description = [{
Traditional scaled dot product attention computes:

attention(Q, K, V, scale) = softmax(Q @ K.T * scale) @ V

Online attention, on the other hand, uses an online normalizer instead
of softmax:

online_attention(Q, K, V, scale, running_max, running_sum)
= online_normalizer(Q @ K.T * scale, running_max, running_sum) @ V

The advantage of this online_normalizer is that it can be tiled along
its reduction dimension, making the online_attention operator:
- Tileable along the softmax reduction dimension
- Associative along the softmax reduction dimension
- Commutative along the softmax reduction dimension

Note: The results of online_attention need to be combined after computing
it over the entire softmax reduction dimension by:
x, _, sum : results
x = (1 / sum) * x
}];

let arguments = (ins AnyShaped:$query,
AnyShaped:$key,
AnyShaped:$value,
AnyFloat:$scale,
AnyShaped:$output,
AnyShaped:$max,
AnyShaped:$sum,
AffineMapArrayAttr:$indexing_maps
);

let results = (outs Variadic<AnyRankedTensor>:$results);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
attr-dict
`ins` `(` $query `,` $key `,` $value `,` $scale `:` type($query) `,` type($key) `,` type($value) `,` type($scale) `)`
`outs` `(` $output `,` $max `,` $sum `:` type($output) `,` type($max) `,` type($sum) `)`
(`->` type($results)^)?
}];

let extraClassDeclaration = [{
// Method to implement for specifying output range for
// DestinationStyleOpInterface
MutableOperandRange getDpsInitsMutable();

SmallVector<AffineMap> getIndexingMapsArray();

AffineMap getQueryMap() {
return getIndexingMapsArray()[0];
}
AffineMap getKeyMap() {
return getIndexingMapsArray()[1];
}
AffineMap getValueMap() {
return getIndexingMapsArray()[2];
}
AffineMap getOutputMap() {
return getIndexingMapsArray()[3];
}
AffineMap getMaxMap() {
return getIndexingMapsArray()[4];
}
AffineMap getSumMap() {
return getIndexingMapsArray()[5];
}

int64_t getIterationDomainRank() {
return getQueryMap().getNumDims();
}
}];
}
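For illustration, here is a hypothetical instance of the op consistent with the assembly format and accessors above; the shapes, affine maps, and exact printed form are assumptions for this sketch, not taken from the patch's tests:

```mlir
// Domain: (batch, m, k1, k2, n); maps ordered Q, K, V, output, max, sum.
%out, %max, %sum = iree_linalg_ext.online_attention
    {indexing_maps = [
       affine_map<(b, m, k1, k2, n) -> (b, m, k1)>,  // Query
       affine_map<(b, m, k1, k2, n) -> (b, k2, k1)>, // Key
       affine_map<(b, m, k1, k2, n) -> (b, k2, n)>,  // Value
       affine_map<(b, m, k1, k2, n) -> (b, m, n)>,   // Output accumulator
       affine_map<(b, m, k1, k2, n) -> (b, m)>,      // running max
       affine_map<(b, m, k1, k2, n) -> (b, m)>]}     // running sum
    ins(%q, %k, %v, %scale
        : tensor<2x128x64xf32>, tensor<2x256x64xf32>,
          tensor<2x256x64xf32>, f32)
    outs(%acc, %m, %s
        : tensor<2x128x64xf32>, tensor<2x128xf32>, tensor<2x128xf32>)
    -> tensor<2x128x64xf32>, tensor<2x128xf32>, tensor<2x128xf32>
```

Here k1 (size 64) is the Q @ K.T contraction dimension, and k2 (size 256) is the softmax reduction dimension along which this op can be tiled.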

} // OpGroupNonStructuredOps

//===----------------------------------------------------------------------===//