Skip to content

Commit

Permalink
Memref shrinkage by analyzing the air.channel data access pattern (Xi…
Browse files Browse the repository at this point in the history
…linx#451)

* After -air-loop-fusion, check for redundant memref allocation by analyzing the new data access pattern

* Update test after wrap-and-stride canonicalizer becomes more conservative

* Memref shrinkage unit test
  • Loading branch information
erwei-xilinx authored Feb 24, 2024
1 parent fb243b2 commit 3c1c256
Show file tree
Hide file tree
Showing 5 changed files with 390 additions and 32 deletions.
20 changes: 10 additions & 10 deletions mlir/lib/Conversion/ConvertToAIRPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2433,37 +2433,37 @@ static LogicalResult canonicalizeAIRDmaOperands(OpBuilder builder,
// default order.
int max_dim_size =
std::max(std::max(offsets.size(), sizes.size()), strides.size());
if (max_dim_size && offsets.size() < (unsigned)memref.getRank()) {
for (unsigned i = offsets.size(); i < memref.getRank(); i++) {
int target_dim_size = std::max(max_dim_size, (int)memref.getRank());
if (max_dim_size && offsets.size() < target_dim_size) {
for (unsigned i = offsets.size(); i < target_dim_size; i++) {
offsets.insert(offsets.begin(), builder.create<arith::ConstantIndexOp>(
builder.getUnknownLoc(), 0));
}
}
if (max_dim_size && sizes.size() < (unsigned)memref.getRank()) {
for (unsigned i = sizes.size(); i < memref.getRank(); i++) {
if (max_dim_size && sizes.size() < target_dim_size) {
for (unsigned i = sizes.size(); i < target_dim_size; i++) {
sizes.insert(sizes.begin(), builder.create<arith::ConstantIndexOp>(
builder.getUnknownLoc(), 1));
}
}
int memref_size = 1;
for (auto size : memref.getShape())
memref_size *= size;
if (max_dim_size && strides.size() < (unsigned)memref.getRank()) {
for (unsigned i = strides.size(); i < memref.getRank(); i++) {
if (max_dim_size && strides.size() < target_dim_size) {
for (unsigned i = strides.size(); i < target_dim_size; i++) {
strides.insert(strides.begin(),
builder.create<arith::ConstantIndexOp>(
builder.getUnknownLoc(), memref_size));
}
}

// Reduce highest dimensions if more than memref size
while (strides.size() > (unsigned)memref.getRank() &&
getConstantIntValue(strides[0]) &&
while (strides.size() > target_dim_size && getConstantIntValue(strides[0]) &&
*getConstantIntValue(strides[0]) == memref_size) {
strides.erase(strides.begin());
}
while (sizes.size() > (unsigned)memref.getRank() &&
getConstantIntValue(sizes[0]) && *getConstantIntValue(sizes[0]) == 1) {
while (sizes.size() > target_dim_size && getConstantIntValue(sizes[0]) &&
*getConstantIntValue(sizes[0]) == 1) {
sizes.erase(sizes.begin());
}
while (offsets.size() > std::min(sizes.size(), strides.size()) &&
Expand Down
Loading

0 comments on commit 3c1c256

Please sign in to comment.