[Pass][Bugfix] Disable re-use of non-flat buffers in StorageRewrite. #10787

Merged (2 commits) on Mar 30, 2022
97 changes: 95 additions & 2 deletions src/tir/transforms/storage_flatten.cc
@@ -1405,12 +1405,25 @@ class StorageFlattener : public StmtExprMutator {
// rather than a buffer_var.
Stmt VisitStmt_(const AllocateNode* op) final {
buffer_var_defines_.insert(op->buffer_var.get());
return StmtExprMutator::VisitStmt_(op);
auto stmt = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));
return Allocate(stmt->buffer_var, stmt->dtype, FlattenExtents(stmt), stmt->condition,
stmt->body, stmt->annotations, stmt->span);
}

Stmt VisitStmt_(const AllocateConstNode* op) final {
buffer_var_defines_.insert(op->buffer_var.get());
return StmtExprMutator::VisitStmt_(op);
auto stmt = Downcast<AllocateConst>(StmtExprMutator::VisitStmt_(op));
ObjectRef data_or_idx;
if (stmt->data) {
data_or_idx = stmt->data.value();
} else if (stmt->irmod_storage_idx) {
data_or_idx = stmt->irmod_storage_idx.value();
} else {
LOG(FATAL) << "Neither data array nor data index specified for allocation of const "
<< op->buffer_var->name_hint;
}
return AllocateConst(stmt->buffer_var, stmt->dtype, FlattenExtents(stmt), data_or_idx,
stmt->body, stmt->span);
}

Stmt VisitStmt_(const LetStmtNode* op) final {
@@ -1598,6 +1611,82 @@ class StorageFlattener : public StmtExprMutator {
}

private:
// Helper function for visiting Allocate and AllocateConst. If, in
// the future, these are updated to hold a buffer (Buffer) object
// rather than a buffer_var (Var), this function can be replaced
// with a call to GetBufferEntry.
template <typename Node>
Array<PrimExpr> FlattenExtents(const Node& node) {
arith::Analyzer analyzer;

// Returns true if the allocation's extents match the buffer's shape.
auto is_compatible_buffer = [&](const Buffer& buffer) {
if (buffer->shape.size() != node->extents.size()) {
return false;
}
for (size_t i = 0; i < buffer->shape.size(); i++) {
if (!analyzer.CanProveEqual(buffer->shape[i], node->extents[i])) {
return false;
}
}

return true;
};

auto int_array_equal = [](const Array<IntImm>& a, const Array<IntImm>& b) {
if (a.size() != b.size()) {
return false;
}

for (size_t i = 0; i < a.size(); i++) {
if (a[i]->value != b[i]->value) {
return false;
}
}

return true;
};

Array<IntImm> axis_separators;
auto it = buffer_var_map_.find(node->buffer_var.get());
if (it != buffer_var_map_.end()) {
const auto& buffers = it->second;
if (buffers.size() == 0) {
// No buffers use this allocation, treat as flat and optimize
// out later.
} else if (buffers.size() == 1) {
// Only one buffer uses this allocation, so use its axis
// separators.
axis_separators = buffers[0]->axis_separators;
} else {
// Try to find a buffer using this allocation with a matching
// shape.
Buffer compatible_buffer;
for (const auto& buffer : buffers) {
if (is_compatible_buffer(buffer)) {
ICHECK(!compatible_buffer.defined() ||
int_array_equal(compatible_buffer->axis_separators, buffer->axis_separators))
<< "Cannot determine axis separators to use when flattening "
<< node->buffer_var->name_hint
<< ", multiple buffer objects found with conflicting axis separators";
compatible_buffer = buffer;
}
}
ICHECK(compatible_buffer.defined())
<< "Cannot determine axis separators to use when flattening "
<< node->buffer_var->name_hint << ", no buffers found with matching shape";
axis_separators = compatible_buffer->axis_separators;
}
}

// Use GetFlattenedBuffer to determine the flattened shape of the
// output. We only need the shape and axis separators defined,
// everything else can be dummy values.
Buffer dummy_buffer =
decl_buffer(node->extents, DataType::Float(32), "buffer", "", axis_separators);
return dummy_buffer.GetFlattenedBuffer()->shape;
}

// The buffer entry in the flatten map
struct DimAlignInfo {
int align_factor{0};
@@ -1665,6 +1754,10 @@ class StorageFlattener : public StmtExprMutator {
// Set of vars that have occurred in an AllocateNode, but haven't
// yet occurred in a BufferLoad/BufferStore.
std::unordered_set<const VarNode*> buffer_var_defines_;
// Map from an allocation variable to the buffer(s) that it backs.
// Used to determine the axis_separators that should be
// used for flattening the extents of an AllocateNode.
std::unordered_map<const VarNode*, std::vector<Buffer>> buffer_var_map_;
// Buffer map
std::unordered_map<Buffer, BufferEntry, ObjectPtrHash, ObjectPtrEqual> buf_map_;
// The extern buffer map, updated to include flattened buffers.
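For context on the storage_flatten.cc change above: FlattenExtents resolves an allocation's physical shape from the axis separators of the buffers backed by that allocation, using the same decl_buffer plus GetFlattenedBuffer path that the diff itself relies on. Below is a minimal standalone sketch of that behaviour, assuming TVM's C++ API as used in the diff; the extent and separator values are illustrative, not taken from the PR.

#include <tvm/tir/buffer.h>
#include <tvm/tir/expr.h>

using namespace tvm;
using namespace tvm::tir;

// Flatten an allocation's extents the same way FlattenExtents does: build a
// throw-away buffer that only carries the shape and axis separators, then ask
// for its flattened shape.  With extents {16, 64, 4} and axis_separators {1},
// axes [0, 1) and [1, 3) are grouped separately, giving the 2-d physical
// shape {16, 256}.
Array<PrimExpr> FlattenExtentsExample() {
  Array<PrimExpr> extents{16, 64, 4};
  Array<IntImm> axis_separators{IntImm(DataType::Int(32), 1)};
  Buffer dummy_buffer =
      decl_buffer(extents, DataType::Float(32), "buffer", "", axis_separators);
  return dummy_buffer.GetFlattenedBuffer()->shape;
}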
77 changes: 60 additions & 17 deletions src/tir/transforms/storage_rewrite.cc
@@ -76,6 +76,8 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
};
// The scope of each allocation
struct AllocEntry {
// The number of physical dimensions of the allocation.
size_t num_physical_dimensions{0};
// scope level
size_t level{0};
// allocation stmt
@@ -85,8 +87,16 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
void VisitStmt_(const AllocateNode* op) final {
size_t level = scope_.size();
const VarNode* buf = op->buffer_var.get();
alloc_info_[buf].alloc = op;
alloc_info_[buf].level = level;

AllocEntry entry;
entry.alloc = op;
entry.level = level;
// Since StorageRewrite occurs after StorageFlatten/FlattenBuffer,
// all allocations specify the extent of each physical dimension;
// a flat memory space has a single physical dimension.
entry.num_physical_dimensions = op->extents.size();
alloc_info_[buf] = entry;

StmtExprVisitor::VisitStmt_(op);
}

@@ -104,6 +114,12 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
if (it != alloc_info_.end() && it->second.alloc) {
ICHECK_LT(it->second.level, scope_.size());
scope_[it->second.level].touched.push_back(buf);

ICHECK_EQ(op->buffer->axis_separators.size() + 1, it->second.num_physical_dimensions)
<< "Buffer " << op->buffer->name << " is allocated with "
<< it->second.num_physical_dimensions
<< " physical dimensions, but is accessed as having "
<< op->buffer->axis_separators.size() + 1 << " physical dimensions" << std::endl;
}
StmtEntry e = scope_.back();
scope_.pop_back();
@@ -125,6 +141,12 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
if (it != alloc_info_.end() && it->second.alloc) {
ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store.";
scope_[it->second.level].touched.push_back(buf);

ICHECK_EQ(op->buffer->axis_separators.size() + 1, it->second.num_physical_dimensions)
<< "Buffer " << op->buffer->name << " is allocated with "
<< it->second.num_physical_dimensions
<< " physical dimensions, but is accessed as having "
<< op->buffer->axis_separators.size() + 1 << " physical dimensions" << std::endl;
}
}

@@ -530,6 +552,10 @@ class StoragePlanRewriter : public StmtExprMutator {
uint64_t const_nbits{0};
// The storage scope.
StorageScope scope;
// The physical dimensionality of the allocations.  Since
// StorageRewrite is applied after StorageFlatten/FlattenBuffer,
// this is the size of `AllocateNode::extents`.  If the pass were
// moved to before StorageFlatten/FlattenBuffer, this would instead
// need to be derived from the backing buffer's shape and axis separators.
size_t ndim;
// Allocs that shares this entry.
std::vector<const AllocateNode*> allocs;
// The children of this entry, not including itself.
@@ -629,8 +655,8 @@ class StoragePlanRewriter : public StmtExprMutator {
// simply use the original allocation.
PrimExpr sz = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
make_const(DataType::Int(32), 1), e->allocs[0]->extents);
e->new_alloc =
Allocate(e->alloc_var, alloc_type, {sz}, e->allocs[0]->condition, Evaluate(0));
e->new_alloc = Allocate(e->alloc_var, alloc_type, e->allocs[0]->extents,
e->allocs[0]->condition, Evaluate(0));
if (IsSpecialTaggedMemory(e->scope)) {
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
uint64_t total_elem = e->const_nbits / e->elem_type.bits();
@@ -641,8 +667,13 @@ class StoragePlanRewriter : public StmtExprMutator {
// Build a merged allocation
PrimExpr combo_size;
for (const AllocateNode* op : e->allocs) {
PrimExpr sz = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
make_const(DataType::Int(32), 1), op->extents);
ICHECK_EQ(op->extents.size(), 1)
<< "Buffer var " << op->buffer_var->name_hint
<< " was identified as a re-usable allocation, but has " << op->extents.size()
<< " physical dimensions. "
<< "Currently, only flat 1-d memory spaces should be identified as re-usable "
"allocations.";
PrimExpr sz = op->extents[0];
auto nbits = op->dtype.bits() * op->dtype.lanes();
if (const auto* imm = sz.as<IntImmNode>()) {
if (imm->value > std::numeric_limits<int>::max() / nbits) {
@@ -790,7 +821,8 @@ class StoragePlanRewriter : public StmtExprMutator {

for (const VarNode* var : it->second.gen) {
ICHECK(alloc_info.count(var));
const AllocateNode* alloc = alloc_info.at(var).alloc;
const AllocEntry& entry = alloc_info.at(var);
const AllocateNode* alloc = entry.alloc;
auto storage_scope = StorageScope::Create(GetPtrStorageScope(GetRef<Var>(var)));
StorageEntry* dst_entry = nullptr;
// inplace detection
@@ -818,7 +850,8 @@ class StoragePlanRewriter : public StmtExprMutator {
}
}
if (dst_entry == nullptr) {
dst_entry = FindAlloc(alloc, thread_scope_, storage_scope);
dst_entry =
FindAlloc(alloc, thread_scope_, storage_scope, entry.num_physical_dimensions);
}
dst_entry->allocs.emplace_back(alloc);
alloc_map_[var] = dst_entry;
@@ -871,24 +904,34 @@ class StoragePlanRewriter : public StmtExprMutator {
}

StorageEntry* FindAlloc(const AllocateNode* op, const Object* attach_scope,
const StorageScope& scope) {
const StorageScope& scope, size_t num_physical_dimensions) {
ICHECK(op != nullptr);
// skip plan for local variable,
// compiler can do a better job with register allocation.
const uint64_t match_range = 16;
uint64_t op_elem_bits = op->dtype.bits() * op->dtype.lanes();
uint64_t const_nbits = static_cast<uint64_t>(op->ConstantAllocationSize() * op_elem_bits);

// If the size of the array isn't known at compile-time, it must
// have its own allocation with size determined at runtime.
bool is_known_size = (const_nbits != 0);

// Currently, only flat memory spaces can be re-used. Packing
// into N-d space (e.g. 2-d texture memory on GPUs) will require
// more in-depth algorithms.
bool is_flat_memory_space = (num_physical_dimensions == 1);

// Disable reuse of small arrays; they will be lowered to registers in LLVM.
// These rules only apply when we are using non-special memory.
if (scope.tag.length() == 0) {
if (scope.rank >= StorageRank::kWarp || op->dtype.is_handle()) {
return NewAlloc(op, attach_scope, scope, const_nbits);
}
if (const_nbits > 0 && const_nbits <= 32) {
return NewAlloc(op, attach_scope, scope, const_nbits);
}
bool is_small_array =
(scope.tag.length() == 0) && (scope.rank >= StorageRank::kWarp || op->dtype.is_handle() ||
(is_known_size && const_nbits <= 32));

if (is_small_array || !is_flat_memory_space) {
return NewAlloc(op, attach_scope, scope, const_nbits);
}
if (const_nbits != 0) {

if (is_known_size) {
// constant allocation.
auto begin = const_free_map_.lower_bound(const_nbits / match_range);
auto mid = const_free_map_.lower_bound(const_nbits);
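To summarize the storage_rewrite.cc change: before searching the free lists for a previously released allocation to re-use, FindAlloc now also requires the candidate to occupy a flat, 1-d physical memory space, and LinearAccessPatternFinder checks that a buffer with N axis separators is backed by an allocation with N + 1 extents. The sketch below is illustrative only and not part of the diff; it collapses the storage-scope and dtype conditions of the real code into a single is_register_like flag, but the small-array and flat-memory logic mirrors the change above.

#include <cstddef>
#include <cstdint>

// Decide whether an allocation may even be considered for packing into a
// previously freed backing store.  `const_nbits` is the compile-time size in
// bits (0 when unknown) and `num_physical_dimensions` is the size of
// AllocateNode::extents after flattening.
bool MayAttemptReuse(uint64_t const_nbits, size_t num_physical_dimensions,
                     bool is_register_like, bool has_special_tag) {
  bool is_known_size = (const_nbits != 0);
  // Only flat memory spaces can share a backing allocation; N-d spaces such
  // as 2-d texture memory always receive a fresh allocation.
  bool is_flat_memory_space = (num_physical_dimensions == 1);
  // Small arrays in untagged memory are left alone so LLVM can lower them to
  // registers.
  bool is_small_array =
      !has_special_tag && (is_register_like || (is_known_size && const_nbits <= 32));
  return !is_small_array && is_flat_memory_space;
}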