Skip to content

Commit

Permalink
remove storage_scope map from storage_access.cc
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Jun 29, 2021
1 parent 0ba5c71 commit fd07b35
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 17 deletions.
20 changes: 6 additions & 14 deletions src/tir/transforms/storage_access.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ namespace tir {

void StorageAccessVisitor::VisitExpr_(const LoadNode* op) {
const VarNode* buf = op->buffer_var.as<VarNode>();
StorageScope scope = GetScope(buf);
StorageScope scope = GetScope(op->buffer_var);
if (Enabled(buf, scope)) {
ICHECK(allow_append_) << op << " " << scope.to_string();
AccessEntry e;
Expand All @@ -56,7 +56,7 @@ void StorageAccessVisitor::VisitStmt_(const StoreNode* op) {
ICHECK_EQ(curr_stmt_.access.size(), 0U);
curr_stmt_.stmt = op;
const VarNode* buf = op->buffer_var.as<VarNode>();
StorageScope scope = GetScope(buf);
StorageScope scope = GetScope(op->buffer_var);
if (Enabled(buf, scope)) {
AccessEntry e;
e.threads = env_threads();
Expand Down Expand Up @@ -90,11 +90,7 @@ void StorageAccessVisitor::VisitStmt_(const EvaluateNode* op) {
}

void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == attr::storage_scope) {
const VarNode* buf = op->node.as<VarNode>();
storage_scope_[buf] = StorageScope::Create(op->value.as<StringImmNode>()->value);
StmtExprVisitor::VisitStmt_(op);
} else if (op->attr_key == attr::double_buffer_write) {
if (op->attr_key == attr::double_buffer_write) {
ICHECK(double_buffer_write_ == nullptr);
double_buffer_write_ = op->node.as<VarNode>();
scope_.push_back(std::vector<StmtEntry>());
Expand Down Expand Up @@ -208,7 +204,7 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) {
PrimExpr offset = op->args[2];
PrimExpr extent = op->args[3];
const IntImmNode* flag = op->args[4].as<IntImmNode>();
StorageScope scope = GetScope(buffer);
StorageScope scope = GetScope(GetRef<Var>(buffer));
// The buffer scope.
if (Enabled(buffer, scope)) {
ICHECK(allow_append_);
Expand Down Expand Up @@ -244,12 +240,8 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) {
}
}

StorageScope StorageAccessVisitor::GetScope(const VarNode* buf) const {
auto it = storage_scope_.find(buf);
StorageScope s;
s.rank = StorageRank::kGlobal;
if (it == storage_scope_.end()) return s;
return it->second;
StorageScope StorageAccessVisitor::GetScope(Var buffer_var) const {
return StorageScope::Create(GetStorageScope(buffer_var));
}

} // namespace tir
Expand Down
4 changes: 1 addition & 3 deletions src/tir/transforms/storage_access.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class StorageAccessVisitor : public StmtExprVisitor {
* \brief Get the scope of the buffer array.
* \return The scope of the final buffer array.
*/
StorageScope GetScope(const VarNode* buf) const;
StorageScope GetScope(Var buffer_var) const;
// access scope
std::vector<std::vector<StmtEntry> > scope_;

Expand All @@ -135,8 +135,6 @@ class StorageAccessVisitor : public StmtExprVisitor {
StmtEntry curr_stmt_;
// The involving threads
Array<IterVar> env_threads_;
// The storage scope of each buffer
std::unordered_map<const VarNode*, StorageScope> storage_scope_;
};

} // namespace tir
Expand Down

0 comments on commit fd07b35

Please sign in to comment.