Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIRSegmentLoopFusion: Fixups on affine::DelinearizeIndexOp and rank reduction #752

Merged
merged 4 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions mlir/lib/Transform/AIRDependencyScheduleOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
#include <string>
#include <vector>

#include <iostream>

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops. PR updated. Thanks for checking!

using namespace mlir;
using namespace xilinx;
using namespace xilinx::air;
Expand Down Expand Up @@ -4559,9 +4561,10 @@ struct ShrinkMemrefSizesByAccessPattern
auto shrunkMemrefType =
MemRefType::get(overall_access_bounds, elemType, nullptr, memorySpace);
MemRefType inferredSubViewOutputTy =
llvm::cast<MemRefType>(memref::SubViewOp::inferResultType(
shrunkMemrefType, subViewOp.getStaticOffsets(),
subViewOp.getStaticSizes(), subViewOp.getStaticStrides()));
llvm::cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
subViewOp.getType().getShape(), shrunkMemrefType,
subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
subViewOp.getStaticStrides()));
// Case 1: static size mismatches the shrunk shape.
for (unsigned i = 0; i < static_sizes.size(); i++) {
if (static_sizes[i] < 0) {
Expand Down
40 changes: 28 additions & 12 deletions mlir/lib/Util/Util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1242,28 +1242,44 @@ static void updateAccessPatternByScfForNest(
&pattern,
SmallVector<Value> indices, OpBuilder builder) {
auto loc = builder.getUnknownLoc();
auto updateWrapAndStride = [&](Value index, int i) {
if (auto scfForOp = scf::getForInductionVarOwner(index)) {
std::get<1>(pattern)[i] = builder.create<arith::ConstantIndexOp>(
loc, *air::getStaticScfForTripCountAsInt(scfForOp));
std::get<2>(pattern)[i] = builder.create<arith::ConstantIndexOp>(
loc, (*getConstantIntValue(scfForOp.getStep())) *
(*getConstantIntValue(std::get<2>(pattern)[i])));

scfForOp.getStep();
auto updateWrapAndStride = [&](int stepSize, int tripCount, int i) {
std::get<1>(pattern)[i] =
builder.create<arith::ConstantIndexOp>(loc, tripCount);
std::get<2>(pattern)[i] = builder.create<arith::ConstantIndexOp>(
loc, stepSize * (*getConstantIntValue(std::get<2>(pattern)[i])));
};
// Infer data access pattern's sizes from parent scf.for loop and any affine
// op applied on the induction variable
auto inferDataAccessSizes = [](scf::ForOp scfForOp, air::ExecuteOp execOp,
Value index) {
int scfForTripCount = *air::getStaticScfForTripCountAsInt(scfForOp);
// If scf.for's iv applies affine::AffineDelinearizeIndexOp
if (auto delinearizeOp =
dyn_cast<affine::AffineDelinearizeIndexOp>(execOp.getChildOp())) {
int resIdx =
llvm::find(execOp.getResults(), index) - execOp.getResults().begin();
scfForTripCount = *getConstantIntValue(delinearizeOp.getBasis()[resIdx]);
}
return scfForTripCount;
};
int dim = -1;
for (auto index : indices) {
dim++;
if (getConstantIntValue(index))
continue;
updateWrapAndStride(index, dim);
if (auto scfForOp = scf::getForInductionVarOwner(index))
updateWrapAndStride(*getConstantIntValue(scfForOp.getStep()),
*air::getStaticScfForTripCountAsInt(scfForOp), dim);
if (!index.getDefiningOp())
continue;
if (auto execOp = dyn_cast<air::ExecuteOp>(index.getDefiningOp()))
if (auto execOp = dyn_cast<air::ExecuteOp>(index.getDefiningOp())) {
for (auto oper : execOp.getChildOp()->getOperands())
updateWrapAndStride(oper, dim);
if (auto scfForOp = scf::getForInductionVarOwner(oper)) {
int scfForTripCount = inferDataAccessSizes(scfForOp, execOp, index);
updateWrapAndStride(*getConstantIntValue(scfForOp.getStep()),
scfForTripCount, dim);
}
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -934,3 +934,63 @@ func.func @func10(%arg0: memref<8x512xi32>, %arg1: memref<256x512xi32>, %arg2: m
}
return
}

// Affine::DelinearizeIndexOp support; rank-reduced memref::SubViewOp.

// CHECK-LABEL: func.func @func11
// CHECK: air.herd
// CHECK: %[[SUBVIEW0:.*]] = memref.subview{{.*}} : memref<16x16x4x4xf32, 1 : i32> to memref<1x1x4x4xf32, strided<[256, 16, 4, 1], offset: ?>, 1 : i32>
// CHECK: %[[SUBVIEW1:.*]] = memref.subview{{.*}} : memref<1x16x4xf32, 2 : i32> to memref<1x4xf32, strided<[4, 1], offset: ?>, 2 : i32>
// CHECK: %[[SUBVIEW2:.*]] = memref.subview{{.*}} : memref<1x1x16x16x4x4xbf16, 2 : i32> to memref<1x1x4x4xbf16, strided<[4096, 4096, 4, 1], offset: ?>, 2 : i32>
// CHECK: linalg.generic{{.*}} ins(%[[SUBVIEW0]], %[[SUBVIEW1]] {{.*}}outs(%[[SUBVIEW2]]

// Indexing maps for the elementwise linalg.generic below: #map17 is identity
// over 4 dims; #map18 broadcasts a 2-D operand across dims (d1, d2).
#map17 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map18 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
// Regression test: a herd-level scf.for whose induction variable is split by
// affine.delinearize_index (inside an air.execute region) and then used to
// build memref.subview ops, including a rank-reducing one (%subview_6).
func.func @func11(%arg0: memref<512x512xbf16>, %arg1: memref<512x16384xbf16>, %arg2: memref<512xf32>, %arg3: memref<512x16384xbf16>) {
%c4 = arith.constant 4 : index
%c128 = arith.constant 128 : index
%0 = air.launch async (%arg4, %arg5) in (%arg6=%c4, %arg7=%c128) attributes {id = 1 : i32} {
%1 = air.segment @matmul_elementwise_bf16_dispatch_0_matmul_512x16384x512_bf16xbf16xf32_0 async attributes {id = 2 : i32} {
%c2 = arith.constant 2 : index
// L2 (memory space 2) buffer later sliced per-tile by %subview_7.
%async_token, %results = air.execute -> (memref<2x2x16x16x4x4xbf16, 2 : i32>) {
%alloc = memref.alloc() : memref<2x2x16x16x4x4xbf16, 2 : i32>
air.execute_terminator %alloc : memref<2x2x16x16x4x4xbf16, 2 : i32>
}
// 3-D L2 buffer; slicing it with unit leading size exercises the
// rank-reduced subview path (see %subview_6 / CHECK SUBVIEW1).
%async_token_0, %results_1 = air.execute -> (memref<1x16x4xf32, 2 : i32>) {
%alloc = memref.alloc() : memref<1x16x4xf32, 2 : i32>
air.execute_terminator %alloc : memref<1x16x4xf32, 2 : i32>
}
// L1 (memory space 1) accumulator buffer sliced by %subview.
%async_token_2, %results_3 = air.execute -> (memref<16x16x4x4xf32, 1 : i32>) {
%alloc = memref.alloc() : memref<16x16x4x4xf32, 1 : i32>
air.execute_terminator %alloc : memref<16x16x4x4xf32, 1 : i32>
}
%2 = air.herd @herd_0 async tile (%arg8, %arg9) in (%arg10=%c2, %arg11=%c2) args(%arg12=%results_3, %arg13=%results_1, %arg14=%results) : memref<16x16x4x4xf32, 1 : i32>, memref<1x16x4xf32, 2 : i32>, memref<2x2x16x16x4x4xbf16, 2 : i32> {
%c16 = arith.constant 16 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c256 = arith.constant 256 : index
%3 = air.wait_all async
// 256-iteration loop; async token threaded through iter_args.
%4 = scf.for %arg15 = %c0 to %c256 step %c1 iter_args(%arg16 = %3) -> (!air.async.token) {
// Split the flat iv %arg15 into a (16, 16) index pair inside an
// air.execute region; both results are consumed by the subviews below.
%async_token_4, %results_5:2 = air.execute [%arg16] -> (index, index) {
%6:2 = affine.delinearize_index %arg15 into (%c16, %c16) : index, index
air.execute_terminator %6#0, %6#1 : index, index
}
// Rank-preserving 4-D slice (1x1x4x4) of the L1 accumulator.
%subview = memref.subview %arg12[%results_5#0, %results_5#1, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<16x16x4x4xf32, 1 : i32> to memref<1x1x4x4xf32, strided<[256, 16, 4, 1], offset: ?>, 1 : i32>
// Rank-reducing slice: 1x1x4 of the 3-D memref collapses to 2-D 1x4.
%subview_6 = memref.subview %arg13[0, %results_5#1, 0] [1, 1, 4] [1, 1, 1] : memref<1x16x4xf32, 2 : i32> to memref<1x4xf32, strided<[4, 1], offset: ?>, 2 : i32>
%subview_7 = memref.subview %arg14[%arg8, %arg9, %results_5#0, %results_5#1, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<2x2x16x16x4x4xbf16, 2 : i32> to memref<1x1x4x4xbf16, strided<[256, 16, 4, 1], offset: ?>, 2 : i32>
// Elementwise f32 add (with broadcast of the 2-D operand) then
// truncation to bf16, executed asynchronously.
%async_token_8 = air.execute [%arg16] {
linalg.generic {indexing_maps = [#map17, #map18, #map17], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%subview, %subview_6 : memref<1x1x4x4xf32, strided<[256, 16, 4, 1], offset: ?>, 1 : i32>, memref<1x4xf32, strided<[4, 1], offset: ?>, 2 : i32>) outs(%subview_7 : memref<1x1x4x4xbf16, strided<[256, 16, 4, 1], offset: ?>, 2 : i32>) {
^bb0(%in: f32, %in_9: f32, %out: bf16):
%6 = arith.addf %in, %in_9 : f32
%7 = arith.truncf %6 : f32 to bf16
linalg.yield %7 : bf16
}
}
%5 = air.wait_all async [%async_token_4, %async_token_8]
scf.yield %5 : !air.async.token
}
}
}
}
return
}
Loading