thread storage scope through pipeline to buffer creation
masahi committed Jun 29, 2021
1 parent 496a215 commit d39a470
Showing 11 changed files with 37 additions and 37 deletions.
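This commit threads the storage scope attached to a schedule stage (via Stage::set_scope) through MakePipeline and BuildRealize into ProducerRealize, and from there into the Buffer created during PrimFunc conversion, rather than leaving the scope only in the realize_scope attribute. The sketch below traces that path end to end; it is not part of the diff, and assumes the TE C++ API at this commit (placeholder, compute, create_schedule, and the InferBound / ScheduleOps / SchedulePostProcToPrimFunc passes this diff touches).

#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>
#include <tvm/te/schedule_pass.h>

using namespace tvm;
using namespace tvm::te;

int main() {
  // A two-stage pipeline where the intermediate B is placed in shared memory.
  Tensor A = placeholder({64}, DataType::Float(32), "A");
  Tensor B = compute({64}, [&](Var i) { return A(i) + 1.0f; }, "B");
  Tensor C = compute({64}, [&](Var i) { return B(i) * 2.0f; }, "C");

  Schedule s = create_schedule({C->op});
  s[B].set_scope("shared");  // the string this commit threads through

  // MakePipeline (inside ScheduleOps) now passes s->scope to BuildRealize,
  // so the ProducerRealize node for B carries storage_scope = "shared" ...
  s = s.normalize();
  auto bounds = InferBound(s);
  Stmt body = ScheduleOps(s, bounds, /*debug_keep_trivial_loop=*/false);

  // ... and TensorToBufferMapper hands that scope to CreateBufferFor, so the
  // Buffer declared for B ends up in the "shared" scope.
  PrimFunc f = SchedulePostProcToPrimFunc({A, C}, body, NullOpt);
  return 0;
}

Carrying the scope on ProducerRealize, and ultimately on the buffer Var's pointer type, lets later passes query it directly instead of re-walking attribute statements.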
include/tvm/te/operation.h (7 additions, 7 deletions)

@@ -128,8 +128,8 @@ class TVM_DLL OperationNode : public Object {
    * \return A realization statement that wraps body.
    */
   virtual Stmt BuildRealize(const Stage& stage,
-                            const std::unordered_map<IterVar, Range>& realize_map,
-                            const Stmt& body) const = 0;
+                            const std::unordered_map<IterVar, Range>& realize_map, const Stmt& body,
+                            String storage_scope = "") const = 0;
   /*!
    * \brief Build the statement that provide the output tensors.
    * \param stage The schedule stage of the op.
@@ -168,7 +168,7 @@ class PlaceholderOpNode : public OperationNode {
   void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                    std::unordered_map<IterVar, Range>* out_dom_map) const final;
   Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
-                    const Stmt& body) const final;
+                    const Stmt& body, String storage_scope = "") const final;
   Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                     bool debug_keep_trivial_loop) const final;
 
@@ -212,7 +212,7 @@ class TVM_DLL BaseComputeOpNode : public OperationNode {
   void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                    std::unordered_map<IterVar, Range>* out_dom_map) const final;
   Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
-                    const Stmt& body) const final;
+                    const Stmt& body, String storage_scope = "") const final;
   virtual size_t num_schedulable_dims() const = 0;
 
   static constexpr const char* _type_key = "BaseComputeOp";
@@ -370,7 +370,7 @@ class ScanOpNode : public OperationNode {
   void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                    std::unordered_map<IterVar, Range>* out_dom_map) const final;
   Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
-                    const Stmt& body) const final;
+                    const Stmt& body, String storage_scope = "") const final;
   Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                     bool debug_keep_trivial_loop) const final;
 
@@ -433,7 +433,7 @@ class ExternOpNode : public OperationNode {
   void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                    std::unordered_map<IterVar, Range>* out_dom_map) const final;
   Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
-                    const Stmt& body) const final;
+                    const Stmt& body, String storage_scope = "") const final;
   Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                     bool debug_keep_trivial_loop) const final;
 
@@ -498,7 +498,7 @@ class HybridOpNode : public OperationNode {
   void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                    std::unordered_map<IterVar, Range>* out_dom_map) const final;
   Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
-                    const Stmt& body) const final;
+                    const Stmt& body, String storage_scope = "") const final;
   Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                     bool debug_keep_trivial_loop) const final;
 
src/te/operation/compute_op.cc (2 additions, 2 deletions)

@@ -260,7 +260,7 @@ void BaseComputeOpNode::GatherBound(const Operation& self,
 
 Stmt BaseComputeOpNode::BuildRealize(const Stage& stage,
                                      const std::unordered_map<IterVar, Range>& realize_map,
-                                     const Stmt& body) const {
+                                     const Stmt& body, String storage_scope) const {
   ICHECK_EQ(stage->op.get(), this);
   Region bounds;
   for (IterVar iv : this->axis) {
@@ -269,7 +269,7 @@ Stmt BaseComputeOpNode::BuildRealize(const Stage& stage,
   Stmt realize = body;
   for (int i = this->num_outputs(); i > 0; --i) {
     Tensor t = stage->op.output(i - 1);
-    realize = tir::ProducerRealize(t, bounds, const_true(), realize);
+    realize = tir::ProducerRealize(t, bounds, const_true(), realize, storage_scope);
     // alignment requirement, only useful for compute
     for (size_t i = 0; i < num_schedulable_dims(); ++i) {
       auto it = stage->iter_var_attrs.find(this->axis[i]);
src/te/operation/extern_op.cc (2 additions, 2 deletions)

@@ -124,7 +124,7 @@ void ExternOpNode::GatherBound(const Operation& self,
 
 Stmt ExternOpNode::BuildRealize(const Stage& stage,
                                 const std::unordered_map<IterVar, Range>& realize_map,
-                                const Stmt& body) const {
+                                const Stmt& body, String storage_scope) const {
   ICHECK_EQ(stage->op.get(), this);
   Stmt realize_body = body;
   for (int k = 0; k < num_outputs(); ++k) {
@@ -133,7 +133,7 @@ Stmt ExternOpNode::BuildRealize(const Stage& stage,
     for (size_t i = 0; i < t->shape.size(); ++i) {
      bounds.push_back(Range::FromMinExtent(make_const(t->shape[i].dtype(), 0), t->shape[i]));
     }
-    realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body);
+    realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body, storage_scope);
   }
   return realize_body;
 }
src/te/operation/hybrid_op.cc (2 additions, 2 deletions)

@@ -144,7 +144,7 @@ void HybridOpNode::GatherBound(const Operation& self,
 
 Stmt HybridOpNode::BuildRealize(const Stage& stage,
                                 const std::unordered_map<IterVar, Range>& realize_map,
-                                const Stmt& body) const {
+                                const Stmt& body, String storage_scope) const {
   // TODO(@were): Add attribute inject here and remove it from hybrid parser.
   ICHECK_EQ(stage->op.get(), this);
   Stmt realize_body = body;
@@ -154,7 +154,7 @@ Stmt HybridOpNode::BuildRealize(const Stage& stage,
     for (size_t i = 0; i < t->shape.size(); ++i) {
       bounds.push_back(Range::FromMinExtent(make_const(t->shape[i].dtype(), 0), t->shape[i]));
     }
-    realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body);
+    realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body, storage_scope);
   }
   return realize_body;
 }
src/te/operation/placeholder_op.cc (1 addition, 1 deletion)

@@ -85,7 +85,7 @@ void PlaceholderOpNode::GatherBound(const Operation& self,
 
 Stmt PlaceholderOpNode::BuildRealize(const Stage& stage,
                                      const std::unordered_map<IterVar, Range>& realize_map,
-                                     const Stmt& body) const {
+                                     const Stmt& body, String storage_scope) const {
   return body;
 }
 
src/te/operation/scan_op.cc (2 additions, 2 deletions)

@@ -234,7 +234,7 @@ void ScanOpNode::GatherBound(const Operation& self,
 }
 
 Stmt ScanOpNode::BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
-                              const Stmt& body) const {
+                              const Stmt& body, String storage_scope) const {
   arith::Analyzer analyzer;
   ICHECK_EQ(stage->op.get(), this);
   Range sdom = dom_map.at(this->scan_axis);
@@ -250,7 +250,7 @@ Stmt ScanOpNode::BuildRealize(const Stage& stage, const std::unordered_map<IterV
       IterVar sp_ax = this->spatial_axis_[sp_idx];
       bounds.push_back(dom_map.at(sp_ax));
     }
-    ret = tir::ProducerRealize(t, bounds, const_true(), ret);
+    ret = tir::ProducerRealize(t, bounds, const_true(), ret, storage_scope);
   }
   return ret;
 }
src/te/schedule/schedule_ops.cc (1 addition, 1 deletion)

@@ -51,7 +51,7 @@ Stmt MakePipeline(const Stage& s, const std::unordered_map<IterVar, Range>& dom_
   if (consumer.defined() && !is_no_op(consumer)) {
     pipeline = SeqStmt({producer, consumer});
   }
-  pipeline = s->op->BuildRealize(s, dom_map, pipeline);
+  pipeline = s->op->BuildRealize(s, dom_map, pipeline, s->scope);
   // use attribute to mark scope of the operation.
   pipeline = AttrStmt(s->op, tir::attr::realize_scope, StringImm(s->scope), pipeline);
 
src/te/schedule/schedule_postproc_to_primfunc.cc (8 additions, 6 deletions)

@@ -49,12 +49,12 @@ namespace tvm {
 namespace te {
 
 // create a buffer for tensor.
-Buffer CreateBufferFor(const Tensor& tensor) {
+Buffer CreateBufferFor(const Tensor& tensor, String storage_scope = "") {
   std::string name = tensor->op->name;
   if (tensor->op->num_outputs() != 1) {
     name += ".v" + std::to_string(tensor->value_index);
   }
-  Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, name);
+  Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, name, storage_scope);
   return buffer;
 }
 
@@ -95,7 +95,7 @@ class TensorToBufferMapper : public StmtExprMutator {
 
   Stmt VisitStmt_(const ProducerRealizeNode* op) final {
     Tensor tensor = Downcast<Tensor>(op->producer);
-    Buffer buffer = GetOrAllocBuffer(tensor);
+    Buffer buffer = GetOrAllocBuffer(tensor, op->storage_scope);
 
     auto ret = StmtExprMutator::VisitStmt_(op);
     op = ret.as<ProducerRealizeNode>();
@@ -122,14 +122,16 @@ class TensorToBufferMapper : public StmtExprMutator {
   }
 
  private:
-  Buffer GetOrAllocBuffer(const Tensor& tensor) { return GetBuffer(tensor, true); }
+  Buffer GetOrAllocBuffer(const Tensor& tensor, String storage_scope = "") {
+    return GetBuffer(tensor, storage_scope, true);
+  }
 
-  Buffer GetBuffer(const Tensor& tensor, bool allow_alloc = false) {
+  Buffer GetBuffer(const Tensor& tensor, String storage_scope = "", bool allow_alloc = false) {
     auto it = buffer_map_.find(tensor);
     if (it != buffer_map_.end()) return it->second;
     ICHECK(allow_alloc) << "Cannot find the Realization point of tensor " << tensor;
 
-    auto buffer = CreateBufferFor(tensor);
+    auto buffer = CreateBufferFor(tensor, storage_scope);
     buffer_map_[tensor] = buffer;
     return buffer;
   }
src/tir/ir/buffer.cc (1 addition, 1 deletion)

@@ -55,7 +55,7 @@ Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype, String name, String st
 String GetStorageScope(Var buffer_var) {
   auto type = buffer_var->type_annotation;
   const auto* ptr_type = type.as<PointerTypeNode>();
-  ICHECK(ptr_type);
+  ICHECK(ptr_type) << "The provided variable is not of pointer type";
   return ptr_type->storage_scope;
 }
 
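Together with CreateBufferFor above, this is the round trip the commit sets up: the scope string passed to decl_buffer lands on the buffer's data Var as a PointerType annotation, and GetStorageScope reads it back. A small sketch, assuming the decl_buffer signature visible in the hunk header (shape, dtype, name, storage_scope):

#include <tvm/tir/buffer.h>

using namespace tvm;
using namespace tvm::tir;

void StorageScopeRoundTrip() {
  // decl_buffer(shape, dtype, name, storage_scope) annotates the buffer's
  // underlying Var with a PointerType carrying "shared".
  Buffer buf = decl_buffer({16, 16}, DataType::Float(32), "B_shared", "shared");
  String scope = GetStorageScope(buf->data);  // == "shared"
  // A Var without a PointerType annotation would now trip the ICHECK above
  // with "The provided variable is not of pointer type".
}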
src/tir/ir/stmt.cc (1 addition, 1 deletion)

@@ -401,7 +401,7 @@ ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr
 TVM_REGISTER_GLOBAL("tir.ProducerRealize")
     .set_body_typed([](DataProducer producer, Region bounds, PrimExpr condition, Stmt body,
                        Span span) {
-      return ProducerRealize(producer, bounds, condition, body, span);
+      return ProducerRealize(producer, bounds, condition, body, "", span);
     });
 
 TVM_REGISTER_NODE_TYPE(ProducerRealizeNode);
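For C++ callers, the new storage_scope argument sits between body and span; the FFI wrapper above passes "" so the Python-facing signature is unchanged. A sketch with hypothetical placeholder values, just to show the argument order:

#include <tvm/te/operation.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>

using namespace tvm;
using namespace tvm::tir;

Stmt MakeSharedRealize() {
  // Hypothetical producer, bounds, and body, purely for illustration.
  te::Tensor t = te::placeholder({8}, DataType::Float(32), "T");
  Region bounds = {Range::FromMinExtent(0, 8)};
  Stmt nop = Evaluate(0);
  // storage_scope precedes the (defaulted) span argument.
  return ProducerRealize(t, bounds, const_true(), nop, "shared");
}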
src/tir/transforms/thread_storage_sync.cc (10 additions, 12 deletions)

@@ -22,6 +22,7 @@
  */
 #include <tvm/runtime/registry.h>
 #include <tvm/tir/analysis.h>
+#include <tvm/tir/buffer.h>
 #include <tvm/tir/builtin.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/stmt_functor.h>
@@ -223,14 +224,14 @@ class ThreadSyncInserter : public StmtExprMutator {
   }
   PrimExpr VisitExpr_(const LoadNode* op) final {
     if (sync_scope_.rank == StorageRank::kGlobal &&
-        GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) {
+        GetScope(op->buffer_var).rank == StorageRank::kGlobal) {
       ++rw_stats_[op->buffer_var].read_count;
     }
     return StmtExprMutator::VisitExpr_(op);
   }
   Stmt VisitStmt_(const StoreNode* op) final {
     if (sync_scope_.rank == StorageRank::kGlobal &&
-        GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) {
+        GetScope(op->buffer_var).rank == StorageRank::kGlobal) {
       ++rw_stats_[op->buffer_var].write_count;
     }
     return StmtExprMutator::VisitStmt_(op);
@@ -264,16 +265,15 @@ class ThreadSyncInserter : public StmtExprMutator {
       PrimExpr expr = StmtExprMutator::VisitExpr_(op);
       op = expr.as<CallNode>();
       ICHECK_EQ(op->args.size(), 5U);
-      const VarNode* buffer_var = op->args[1].as<VarNode>();
-      Var var(GetRef<Var>(buffer_var));
+      Var buffer_var(GetRef<Var>(op->args[1].as<VarNode>()));
       const IntImmNode* flag = op->args[4].as<IntImmNode>();
       if ((flag->value & 1) && sync_scope_.rank == StorageRank::kGlobal &&
           GetScope(buffer_var).rank == StorageRank::kGlobal) {
-        ++rw_stats_[var].read_count;
+        ++rw_stats_[buffer_var].read_count;
       }
       if (flag->value & 2 && sync_scope_.rank == StorageRank::kGlobal &&
           GetScope(buffer_var).rank == StorageRank::kGlobal) {
-        ++rw_stats_[var].write_count;
+        ++rw_stats_[buffer_var].write_count;
       }
       return expr;
     } else {
@@ -287,14 +287,12 @@ class ThreadSyncInserter : public StmtExprMutator {
     int read_count{0};
     int write_count{0};
   };
 
   // Get current storage scope.
-  StorageScope GetScope(const VarNode* buf) const {
-    auto it = storage_scope_.find(buf);
-    StorageScope s;
-    s.rank = StorageRank::kGlobal;
-    if (it == storage_scope_.end()) return s;
-    return it->second;
+  StorageScope GetScope(Var buffer_var) const {
+    return StorageScope::Create(GetStorageScope(buffer_var));
   }
 
   // private functions.
   Stmt InitGlobalBarrier(const AttrStmtNode* op) {
     ICHECK(op != nullptr);
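The design shift in this file: GetScope no longer consults the storage_scope_ side table and instead recovers the scope from the buffer Var's pointer type annotation, which is why the pass now includes tvm/tir/buffer.h for GetStorageScope. A hypothetical illustration (StorageScope lives in the internal header src/runtime/thread_storage_scope.h, and the PointerType storage-scope constructor is assumed from the companion commits in this series):

#include <tvm/ir/type.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/var.h>

#include "../../runtime/thread_storage_scope.h"  // internal header, as used by this pass

using namespace tvm;
using namespace tvm::tir;
using runtime::StorageRank;
using runtime::StorageScope;

void ScopeFromAnnotation() {
  // A buffer var annotated as pointing into shared memory.
  Var data("A", PointerType(PrimType(DataType::Float(32)), "shared"));
  StorageScope scope = StorageScope::Create(GetStorageScope(data));
  ICHECK(scope.rank == StorageRank::kShared);
}

Since the scope now travels with the Var itself, this pass stays in sync with buffer creation without rebuilding a scope lookup table from attribute statements.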
