[LinalgExt] Add online_attention op #17536

Merged · 31 commits · Jun 12, 2024

Changes from all commits · Commits (31):
9549091
Split tests
Groverkss May 22, 2024
872ca3e
Address comments
Groverkss May 28, 2024
a3e3471
save
Groverkss May 28, 2024
b53734d
save
Groverkss May 29, 2024
3f35acb
add online attention op
Groverkss May 29, 2024
1e5c190
Implement TilingInterface for online attention
Groverkss May 29, 2024
c68d28f
refactor some impl
Groverkss May 29, 2024
e3f5896
Add aggregate op interface for online_attention
Groverkss May 31, 2024
5993feb
add dtype conversions and convert to online attention pass
Groverkss Jun 3, 2024
c10ff9c
remove redundant functions
Groverkss Jun 3, 2024
1765894
Make llvmcpu backend use online attention
Groverkss Jun 4, 2024
a69bf87
Remove redundant comments
Groverkss Jun 4, 2024
283f5cc
add test for tiling
Groverkss Jun 4, 2024
6feaa37
clang-format
Groverkss Jun 5, 2024
c3fb664
add decompose test
Groverkss Jun 5, 2024
0974aee
Add docs for online_attention
Groverkss Jun 5, 2024
2e96982
bazeltocamke
Groverkss Jun 5, 2024
8835a84
remove todo
Groverkss Jun 5, 2024
791a31a
address comments
Groverkss Jun 7, 2024
f937879
Move aggregate op implementation to seperate file
Groverkss Jun 7, 2024
5d6f8cc
addreess comments
Groverkss Jun 7, 2024
b52d70d
fix compilation error
Groverkss Jun 10, 2024
4f0a7e9
Address hanhan's comments
Groverkss Jun 11, 2024
f490be5
pre-commit
Groverkss Jun 11, 2024
ac149e3
dummy reduction tile sizes for winograd
Groverkss Jun 11, 2024
ec7aff2
fix tests
Groverkss Jun 11, 2024
c249c98
fix test
Groverkss Jun 12, 2024
713c95b
BAZEL :cry:
Groverkss Jun 12, 2024
5157ed5
Revert "BAZEL :cry:"
Groverkss Jun 12, 2024
01fca7d
BAZEL BAZEL
Groverkss Jun 12, 2024
3e2f54f
BEZELL
Groverkss Jun 12, 2024
42 changes: 22 additions & 20 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -890,7 +890,7 @@ getDefaultDistributedLevelTileSizes(Operation *op,
/// Splits the tile sizes in `parallelSizes` into `reductionSizes` for the
/// reduction loops.
static void splitParallelAndReductionTiles(
linalg::LinalgOp op, SmallVectorImpl<int64_t> &parallelSizes,
Operation *op, SmallVectorImpl<int64_t> &parallelSizes,
SmallVectorImpl<int64_t> &reductionSizes,
SmallVectorImpl<bool> *parallelScalableFlags = nullptr,
SmallVectorImpl<bool> *reductionScalableFlags = nullptr) {
@@ -900,8 +900,9 @@ static void splitParallelAndReductionTiles(
reductionScalableFlags->assign(parallelScalableFlags->begin(),
parallelScalableFlags->end());
}
TilingInterface tilingOp = cast<TilingInterface>(op);
for (auto [index, iteratorType] :
llvm::enumerate(op.getIteratorTypesArray())) {
llvm::enumerate(tilingOp.getLoopIteratorTypes())) {
if (iteratorType == utils::IteratorType::parallel) {
reductionSizes[index] = 0;
if (reductionScalableFlags)
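The hunk above generalizes the helper from `linalg::LinalgOp` to any op implementing `TilingInterface`, so ops such as `online_attention` can reuse it. Below is a minimal, editor-added sketch of the splitting logic, assuming only `TilingInterface`; the simplified signature and the omission of the scalable-flag handling are illustrative choices, not the actual IREE helper.

```cpp
// Illustrative sketch, not the IREE implementation: zero out reduction dims in
// the parallel tile sizes and parallel dims in the reduction tile sizes, using
// the iterator types reported by TilingInterface.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Interfaces/TilingInterface.h"

static void splitTileSizesSketch(mlir::TilingInterface tilingOp,
                                 llvm::SmallVectorImpl<int64_t> &parallelSizes,
                                 llvm::SmallVectorImpl<int64_t> &reductionSizes) {
  // Begin with a copy of the parallel sizes for every loop.
  reductionSizes.assign(parallelSizes.begin(), parallelSizes.end());
  for (auto [index, iteratorType] :
       llvm::enumerate(tilingOp.getLoopIteratorTypes())) {
    if (iteratorType == mlir::utils::IteratorType::parallel)
      reductionSizes[index] = 0; // parallel loop: no reduction-level tiling
    else
      parallelSizes[index] = 0;  // reduction loop: no parallel-level tiling
  }
}
```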
@@ -1121,9 +1122,9 @@ setMatmulRootConfig(mlir::FunctionOpInterface entryPointFn,
SmallVector<int64_t> parallelTileSizes = vecTileSizes;
SmallVector<int64_t> reductionTileSizes;
SmallVector<bool> reductionScalableFlags;
splitParallelAndReductionTiles(
cast<linalg::LinalgOp>(op.getOperation()), parallelTileSizes,
reductionTileSizes, &parallelScalableFlags, &reductionScalableFlags);
splitParallelAndReductionTiles(op, parallelTileSizes, reductionTileSizes,
&parallelScalableFlags,
&reductionScalableFlags);

if (vecPreProcStrategy == VectorPreProcStrategy::None) {
setVectorSizesForDynamicShapes(cast<linalg::LinalgOp>(op.getOperation()),
Expand Down Expand Up @@ -1751,14 +1752,13 @@ static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn,

// Batch, M and N (parallel dimensions) are distributed on workgroups.
DistributionHeuristicConfig config;
SmallVector<int64_t> distTileSizes = getDefaultDistributedLevelTileSizes(
attnOp, DistributionHeuristicConfig{});
SmallVector<int64_t> distTileSizes =
getDefaultDistributedLevelTileSizes(attnOp, config);

// Batch, M and N (parallel dimensions) are distributed on workgroups.
SmallVector<int64_t> vecTileSizes(attnOp.getIterationDomainRank(), 1);
// Mark reduction dimensions not to distribute.
for (int64_t i :
llvm::concat<const int64_t>(opInfo.getK1Dims(), opInfo.getK2Dims())) {
// Mark k1 reduction dimensions not to distribute.
for (int i : opInfo.getK1Dims()) {
vecTileSizes[i] = 0;
}
int64_t vectorSize = getVectorSize(entryPointFn, attnOp.getOutputType());
@@ -1773,18 +1773,17 @@ static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn,
/*numElem=*/tileSize, vectorSize, vectorSize);
}

// TODO (17467): Due to a bug in TileAndDecomposeAttention, N dimension
// cannot be tiled. Remove this once fixed.
for (int64_t i : opInfo.getNDims()) {
distTileSizes[i] = 0;
vecTileSizes[i] = 0;
}
SmallVector<int64_t> parallelTileSizes = vecTileSizes;
SmallVector<int64_t> reductionTileSizes;
splitParallelAndReductionTiles(attnOp, parallelTileSizes, reductionTileSizes);

TileSizesListType tileSizes = {distTileSizes, vecTileSizes};
LLVM_DEBUG(KD_DBGS() << "Vectorization/unrolling tile sizes (parallel): "
<< parallelTileSizes << "\n");
LLVM_DEBUG(KD_DBGS() << "Vectorization/unrolling tile sizes (reduction): "
<< reductionTileSizes << "\n");

// TODO: (Groverkss): Tile K2 here using reduction tiling interface once we
// have it. TileAndDecomposeAttention pass only tiles K2. I think it should
// be possible to tile K1 also, but need to explore it more.
TileSizesListType tileSizes = {distTileSizes, parallelTileSizes,
reductionTileSizes};

return setOpConfigAndEntryPointFnTranslation(
entryPointFn, attnOp, tileSizes,
Expand Down Expand Up @@ -1843,6 +1842,9 @@ setWinogradRootConfig(mlir::FunctionOpInterface entryPointFn,
tileSizes.push_back(distTileSizes);
SmallVector<int64_t> vecTileSizes(iterationRank, 1);
tileSizes.push_back(vecTileSizes);
// Dummy tiling config for reduction level.
SmallVector<int64_t> reductionTileSizes(iterationRank, 0);
tileSizes.push_back(reductionTileSizes);
return setOpConfigAndEntryPointFnTranslation(
entryPointFn, winogradOp, tileSizes,
DispatchLoweringPassPipeline::CPULinalgExtTileAndVectorize);
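To make the three tiling levels concrete, here is a small editor-added illustration (not code from the PR). The assumed iteration-domain order (batch, M, K1, K2, N) and the sample sizes mirror the updated `@attention` lowering-config test later in this diff; only K2, the softmax reduction dimension, gets a nonzero tile at the new reduction level, while K1 is left untiled for now.

```cpp
// Editorial illustration of the three-level tile-sizes list for attention.
// Assumed iteration-domain order: (batch, M, K1, K2, N).
#include <cstdint>
#include <vector>

int main() {
  // Level 0: workgroup distribution. Only parallel dims (batch, M, N) are tiled.
  std::vector<int64_t> distTileSizes      = {20, 64, 0, 0, 64};
  // Level 1: vector/parallel tiling. Reduction dims are zeroed out here.
  std::vector<int64_t> parallelTileSizes  = {20, 32, 0, 0, 32};
  // Level 2: reduction tiling. Only K2 is tiled at this level.
  std::vector<int64_t> reductionTileSizes = { 0,  0, 0, 32, 0};
  // The PR records {dist, parallel, reduction} as the op's lowering_config,
  // which is what the updated CHECK lines below verify.
  (void)distTileSizes;
  (void)parallelTileSizes;
  (void)reductionTileSizes;
  return 0;
}
```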
7 changes: 5 additions & 2 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -597,10 +597,13 @@ void addCPULinalgExtTileAndVectorizePipeline(
createLLVMCPUTilePass(tilingConfig.getVectorCommonParallelLevel()));
// TODO: Remove the pass once we have PartialReductionOpInterface implemented
// for AttentionOp.
funcPassManager.addPass(IREE::LinalgExt::createTileAttentionPass());
funcPassManager.addPass(IREE::LinalgExt::createDecomposeAttentionPass());
funcPassManager.addPass(
IREE::LinalgExt::createConvertAttentionToOnlineAttentionPass());
Comment on lines +600 to +601

Contributor (reviewer):
Can this happen before tiling parallel dims? Is it a requirement for tiling reduction loops?

Groverkss (author), Jun 7, 2024:
You can only tile reduction loops on the online_attention op. We could do this before tiling parallel dims, but we would then need to propagate the lowering_config info in the createConvertAttentionToOnlineAttention pass. For more context, the conversion rewrites

attention { lowering_config }

into

acc = acc_fill
max = max_fill
sum = sum_fill
out:3 = online_attention acc, max, sum {lowering_config}
elementwise out#0, out#2

The lowering_config is preserved on the online_attention op and is used for reduction tiling. Until we have consumer fusion (and greedy fusion for multiple operands/results) fixed, I don't think we can do it.

As a side note, this doesn't allow us to do further levels of parallel tiling on the elementwise and fill operations (which is not ideal).

Groverkss (author), Jun 7, 2024:
Ideally, I would like a way to propagate the lowering_config attribute when doing a conversion like this (which would mean putting the tiling information on the type, or somewhere more persistent).

Contributor (reviewer):
> We could do this before tiling parallel dims, but we would then need to propagate lowering_config info in createConvertAttentionToOnlineAttention pass.

This is more of a question than a requirement to address in this PR; I'm trying to see the whole picture of how it could be done in the CPU backend.

It seems we could convert the op to online_attention before lowering-strategy selection, as we did for the softmax op. Do you think we want to keep it in attention form while tiling the parallel loops, or does it not matter as long as we "tile the online_attention op and fuse its producers/consumers into the for loop"?

Groverkss (author):
Ah, I understand what you mean now. I can try. I'm thinking there might be problems with fusion because the online_attention op has multiple results. Let me try and see if I can do it.

Contributor (reviewer):
No need to try it and land it in this PR; the PR is already big, and this is fairly new to the CPU backends. I can pull in others to help with the CPU changes later. Are there other pending changes for the attention ops?
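To sketch the propagation the author describes (this is not code from the PR; the helper name is invented and the `lowering_config` attribute key is an assumption), a conversion could forward the attribute with generic `Operation` accessors:

```cpp
// Editorial sketch only: one way a rewrite from attention to online_attention
// could carry the lowering_config across, using generic attribute APIs.
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Operation.h"

static void propagateLoweringConfig(mlir::Operation *attentionOp,
                                    mlir::Operation *onlineAttentionOp) {
  // Assumed name of the attribute under which the tiling levels are stored.
  constexpr llvm::StringLiteral kConfigName = "lowering_config";
  if (mlir::Attribute config = attentionOp->getAttr(kConfigName))
    onlineAttentionOp->setAttr(kConfigName, config);
}
```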

funcPassManager.addPass(
createLLVMCPUTilePass(tilingConfig.getVectorReductionLevel()));
funcPassManager.addPass(
IREE::LinalgExt::createDecomposeWinogradTransformPass());
funcPassManager.addPass(IREE::LinalgExt::createDecomposeAttentionPass());

{
GenericVectorizationPassOptions options;
@@ -1531,7 +1531,7 @@ module {
return
}
}
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 1, 6, 64], [1, 1, 1, 1]]>
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 1, 6, 64], [1, 1, 1, 1], [0, 0, 0, 0]]>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPULinalgExtTileAndVectorize>
// CHECK: func.func @winograd_output_transform()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
@@ -1556,7 +1556,7 @@ module {
return
}
}
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 1, 6, 64], [1, 1, 1, 1]]>
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[0, 1, 6, 64], [1, 1, 1, 1], [0, 0, 0, 0]]>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPULinalgExtTileAndVectorize>
// CHECK: func.func @winograd_input_transform()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
@@ -1581,7 +1581,7 @@ module {
return
}
}
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[8, 64], [1, 1]]>
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[8, 64], [1, 1], [0, 0]]>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPULinalgExtTileAndVectorize>
// CHECK: func.func @winograd_filter_transform()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
@@ -1613,7 +1613,7 @@ module {
return
}
}
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[20, 64, 0, 0, 0], [20, 32, 0, 0, 0]]>
// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[20, 64, 0, 0, 64], [20, 32, 0, 0, 32], [0, 0, 0, 32, 0]]>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPULinalgExtTileAndVectorize>
// CHECK: func.func @attention()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
9 changes: 7 additions & 2 deletions compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel
@@ -29,6 +29,7 @@ iree_td_library(
"@llvm-project//mlir:ControlFlowInterfacesTdFiles",
"@llvm-project//mlir:DestinationStyleOpInterfaceTdFiles",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
"@llvm-project//mlir:LinalgOpsTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PDLDialectTdFiles",
"@llvm-project//mlir:SideEffectInterfacesTdFiles",
@@ -159,7 +160,9 @@ iree_gentbl_cc_library(
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "LinalgExtOps.td",
deps = [":td_files"],
deps = [
":td_files",
],
)

iree_gentbl_cc_library(
@@ -212,5 +215,7 @@ iree_tablegen_doc(
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "LinalgExtOps.td",
deps = [":td_files"],
deps = [
":td_files",
],
)
@@ -8,6 +8,7 @@

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
@@ -1315,6 +1316,9 @@ LogicalResult AttentionOp::verify() {
for (auto [i, dimExpr] : llvm::enumerate(indexingMap.getResults())) {
AffineDimExpr dim = cast<AffineDimExpr>(dimExpr);
int64_t pos = dim.getPosition();
if (ShapedType::isDynamic(valShape[i])) {
continue;
}
if (!foundDims[pos]) {
foundDims[pos] = true;
shape[pos] = valShape[i];
@@ -1427,6 +1431,79 @@ SmallVector<AffineMap> AttentionOp::getIndexingMapsArray() {
return results;
}

//===----------------------------------------------------------------------===//
// OnlineAttentionOp
//===----------------------------------------------------------------------===//

LogicalResult OnlineAttentionOp::verify() {
OnlineAttentionOp attnOp = *this;

SmallVector<AffineMap> indexingMaps = attnOp.getIndexingMapsArray();

// Check if indexing maps can represent attention.
FailureOr<AttentionOpDetail> maybeOpInfo =
AttentionOpDetail::get(indexingMaps);

// Check shape compatibility based on indexing maps.
SmallVector<int64_t> shape(getIterationDomainRank());
SmallVector<bool> foundDims(getIterationDomainRank(), false);
auto checkShape = [&shape, &foundDims,
&attnOp](StringRef operandName, ArrayRef<int64_t> valShape,
AffineMap indexingMap) -> LogicalResult {
if (indexingMap.getNumResults() != valShape.size()) {
return attnOp->emitError("Rank Mismatch for ")
<< operandName << ". Expected: " << indexingMap.getNumResults()
<< " Got: " << valShape.size();
}
for (auto [i, dimExpr] : llvm::enumerate(indexingMap.getResults())) {
AffineDimExpr dim = cast<AffineDimExpr>(dimExpr);
int64_t pos = dim.getPosition();
if (ShapedType::isDynamic(valShape[i])) {
continue;
}
if (!foundDims[pos]) {
foundDims[pos] = true;
shape[pos] = valShape[i];
}
if (shape[pos] != valShape[i]) {
return attnOp->emitError("Shape Mismatch for ")
<< operandName << ". Expected: " << shape[pos]
<< " Got: " << valShape[i];
}
}
return success();
};

if (failed(checkShape("Query", getQuery().getType().getShape(),
getQueryMap())) ||
failed(checkShape("Key", getKey().getType().getShape(), getKeyMap())) ||
failed(checkShape("Value", getValue().getType().getShape(),
getValueMap())) ||
failed(checkShape("Output", getOutput().getType().getShape(),
getOutputMap())) ||
failed(checkShape("Max", getMax().getType().getShape(), getMaxMap())) ||
failed(checkShape("Sum", getSum().getType().getShape(), getSumMap()))) {
return failure();
}

return success();
}

MutableOperandRange OnlineAttentionOp::getDpsInitsMutable() {
return MutableOperandRange(*this, /*numInputs=*/4, /*numInits=*/3);
}

LogicalResult OnlineAttentionOp::reifyResultShapes(
OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
return cast<LinalgExtOp>(getOperation())
.reifyResultShapes(b, reifiedReturnShapes);
}

SmallVector<AffineMap> OnlineAttentionOp::getIndexingMapsArray() {
return SmallVector<AffineMap>(
getIndexingMaps().getAsValueRange<AffineMapAttr>());
}

#define DEFINE_OP_GET_EFFECTS(OP_NAME) \
void OP_NAME::getEffects( \
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> \
@@ -1446,6 +1523,7 @@ DEFINE_OP_GET_EFFECTS(WinogradInputTransformOp)
DEFINE_OP_GET_EFFECTS(WinogradFilterTransformOp)
DEFINE_OP_GET_EFFECTS(WinogradOutputTransformOp)
DEFINE_OP_GET_EFFECTS(AttentionOp)
DEFINE_OP_GET_EFFECTS(OnlineAttentionOp)

} // namespace mlir::iree_compiler::IREE::LinalgExt

@@ -7,6 +7,7 @@
#ifndef IREE_COMPILER_DIALECT_LINALGEXT_IR_LINALGEXTOPS_H_
#define IREE_COMPILER_DIALECT_LINALGEXT_IR_LINALGEXTOPS_H_

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/Attributes.h"
91 changes: 91 additions & 0 deletions compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -9,6 +9,7 @@

include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtBase.td"
include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td"
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -678,6 +679,96 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_Op<"attention",
}];
}

//===----------------------------------------------------------------------===//
// OnlineAttention
//===----------------------------------------------------------------------===//

def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_PureOp<"online_attention",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DestinationStyleOpInterface, LinalgExtInterface,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
DeclareOpInterfaceMethods<TilingInterface,
["getIterationDomain",
"getLoopIteratorTypes",
"getResultTilePosition",
"getTiledImplementation"]>]> {
let summary = "Online Attention operator";
let description = [{
Traditional scaled dot product attention computes:

attention(Q, K, V, scale) = softmax(Q @ K.T * scale) @ V

Online Attention on the other hand, uses an online normalizer instead of
softmax:

online_attention(Q, K, V, scale, running_max, running_sum)
= online_normalizer(Q @ K.T * scale, running_max, running_sum) @ V

The advantage of this online_normalizer is that it can be tiled along
    its reduction dimension, making the online_attention operator:
- Tilable along softmax reduction dimension
- Associative along softmax reduction dimension
- Commutative along softmax associative dimension

Note: The results of online_attention need to be combined after computing
it over the entire softmax reduction dimension by:
x, _, sum : results
x = (1 / sum) * x
}];

let arguments = (ins AnyShaped:$query,
AnyShaped:$key,
AnyShaped:$value,
AnyFloat:$scale,
AnyShaped:$output,
AnyShaped:$max,
AnyShaped:$sum,
AffineMapArrayAttr:$indexing_maps
);

let results = (outs Variadic<AnyRankedTensor>:$results);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
attr-dict
`ins` `(` $query `,` $key `,` $value `,` $scale `:` type($query) `,` type($key) `,` type($value) `,` type($scale) `)`
`outs` `(` $output `,` $max `,` $sum `:` type($output) `,` type($max) `,` type($sum) `)`
(`->` type($results)^)?
}];

let extraClassDeclaration = [{
// Method to implement for specifying output range for
// DestinationStyleOpInterface
MutableOperandRange getDpsInitsMutable();

SmallVector<AffineMap> getIndexingMapsArray();

AffineMap getQueryMap() {
return getIndexingMapsArray()[0];
}
AffineMap getKeyMap() {
return getIndexingMapsArray()[1];
}
AffineMap getValueMap() {
return getIndexingMapsArray()[2];
}
AffineMap getOutputMap() {
return getIndexingMapsArray()[3];
}
AffineMap getMaxMap() {
return getIndexingMapsArray()[4];
}
AffineMap getSumMap() {
return getIndexingMapsArray()[5];
}

int64_t getIterationDomainRank() {
return getQueryMap().getNumDims();
}
}];
}
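As an editorial companion to the description above (not code from this PR), the following self-contained program runs the online-normalizer recurrence for a single query row: the softmax reduction (K2) dimension is processed in tiles while a running max, running sum, and unnormalized accumulator are carried across tiles, and the final `x = (1 / sum) * x` rescaling from the note is applied at the end. Shapes, tile size, and data are arbitrary.

```cpp
// Editorial sketch: scalar reference of online attention for one query row.
// Shows why the softmax reduction (K2) dimension can be tiled: the running
// max/sum and the partial accumulator are updated per tile, and the result is
// only rescaled by 1/sum once the whole reduction dimension has been visited.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int k2 = 8, d = 4, n = 4, tile = 2;  // arbitrary small sizes
  const float scale = 1.0f / std::sqrt((float)d);
  std::vector<float> q(d, 1.0f);
  std::vector<std::vector<float>> K(k2, std::vector<float>(d, 0.5f));
  std::vector<std::vector<float>> V(k2, std::vector<float>(n, 2.0f));

  // Running state carried across K2 tiles: unnormalized output acc, running
  // max m, running sum s; these correspond to the op's three inits
  // (output, max, sum).
  std::vector<float> acc(n, 0.0f);
  float m = -INFINITY, s = 0.0f;

  for (int start = 0; start < k2; start += tile) {
    for (int j = start; j < start + tile && j < k2; ++j) {
      float logit = 0.0f;
      for (int c = 0; c < d; ++c) logit += q[c] * K[j][c];
      logit *= scale;
      float newMax = std::max(m, logit);
      float correction = std::exp(m - newMax);  // rescales old acc and sum
      float p = std::exp(logit - newMax);
      for (int c = 0; c < n; ++c) acc[c] = acc[c] * correction + p * V[j][c];
      s = s * correction + p;
      m = newMax;
    }
  }
  // Final combine step from the op description: x = (1 / sum) * x.
  for (int c = 0; c < n; ++c) acc[c] /= s;
  std::printf("out[0] = %f (expected 2.0, since every value element is 2.0)\n",
              acc[0]);
  return 0;
}
```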

} // OpGroupNonStructuredOps

//===----------------------------------------------------------------------===//