diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 934ce0c5ec9f..8b5ad9dab80b 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -178,6 +178,14 @@ constexpr const char* pragma_scope = "pragma_scope"; * run prefetch of Tensor on the current loop scope */ constexpr const char* prefetch_scope = "prefetch_scope"; +/*! + * \brief Marks the production of double buffer data + */ +constexpr const char* double_buffer_scope = "double_buffer_scope"; +/*! + * \brief Marks the region written by a double buffer write + */ +constexpr const char* double_buffer_write = "double_buffer_write"; /*! \brief Mark of scan update scope */ constexpr const char* scan_update_scope = "scan_update_scope"; /*! \brief Mark of scan init scope */ diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 4dd3fd82ede7..1ac3887a489c 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -231,6 +231,14 @@ Stmt InjectVirtualThread(Stmt stmt); */ Stmt InjectPrefetch(Stmt stmt); +/*! + * \brief Inject double buffering into stmt. + * \param stmt The statement to be transformed. + * \param split_loop Whether to split the loop containing the double buffer. + * \return Transformed stmt. + */ +Stmt InjectDoubleBuffer(Stmt stmt, bool split_loop); + /*! * \brief Rewrite storage allocation pattern. * Moves the allocation to outer most possible scope. diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h index 0f846381314f..aeb5ffa66f93 100644 --- a/include/tvm/schedule.h +++ b/include/tvm/schedule.h @@ -208,6 +208,11 @@ class Stage : public NodeRef { * \return reference to self */ Stage& storage_align(IterVar axis, int factor, int offset); //NOLINT(*) + /*! + * \brief Compute the current stage with double buffering. + * \return reference to self. + */ + Stage& double_buffer(); // NOLINT(*) /*! * \brief whether the stage has been scheduled. * \return whether the stage has been scheduled. @@ -408,6 +413,8 @@ class StageNode : public Node { std::string scope; /*! \brief Whether this is an output stage */ bool is_output{false}; + /*! \brief Whether to apply double buffer optimization to this stage */ + bool double_buffer{false}; /*! * \brief The parent group of the current stage. * The stage cannot be assigned to stages outside the group. @@ -429,6 +436,7 @@ class StageNode : public Node { v->Visit("attach_stage", &attach_stage); v->Visit("scope", &scope); v->Visit("is_output", &is_output); + v->Visit("double_buffer", &double_buffer); v->Visit("group", &group); v->Visit("num_child_stages", &num_child_stages); } diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index e14d6f5ba848..09cc7747bc55 100644 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -33,6 +33,7 @@ class BuildConfig(object): "offset_factor": 0, "data_alignment": -1, "restricted_func": True, + "double_buffer_split_loop": True, "add_lower_pass": None } def __init__(self, **kwargs): @@ -97,6 +98,10 @@ def build_config(**kwargs): not to overlap. This enables more optimization. Corresponds to restricted keyword in C99 + double_buffer_split_loop: bool, default=True + Whether to split the loop containing the double buffer so + that the buffer fetch does not contain a condition. + add_lower_pass: list of function(Stmt->Stmt), default=None Additional lowering passes to be applied before make_api. @@ -187,6 +192,7 @@ def lower(sch, Then the Stmt before make api is returned.
""" binds, arg_list = get_binds(args, binds) + cfg = BuildConfig.current # normalize schedule first sch = sch.normalize() bounds = schedule.InferBound(sch) @@ -198,8 +204,8 @@ def lower(sch, stmt = ir_pass.LoopPartition(stmt) stmt = ir_pass.VectorizeLoop(stmt) stmt = ir_pass.InjectVirtualThread(stmt) + stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop) stmt = ir_pass.StorageRewrite(stmt) - cfg = BuildConfig.current stmt = ir_pass.UnrollLoop( stmt, cfg.auto_unroll_max_step, diff --git a/python/tvm/ir_builder.py b/python/tvm/ir_builder.py index defd975bc917..1888cd9f1d18 100644 --- a/python/tvm/ir_builder.py +++ b/python/tvm/ir_builder.py @@ -268,6 +268,21 @@ def _exit_cb(): self.emit(_make.IfThenElse(prev.condition, prev.then_case, self._pop_seq())) return WithScope(None, _exit_cb) + def new_scope(self): + """Create new scope, + + this is useful to set boundary of attr and allocate. + + Returns + ------- + new_scope : WithScope + The result new scope. + """ + self._seq_stack.append([]) + def _exit_cb(): + self.emit(self._pop_seq()) + return WithScope(None, _exit_cb) + def allocate(self, dtype, shape, name="buf", scope=None): """Create a allocate statement. diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py index 6dca90a7fddb..ecaeb50bc6e5 100644 --- a/python/tvm/schedule.py +++ b/python/tvm/schedule.py @@ -589,4 +589,13 @@ def storage_align(self, axis, factor, offset): """ _api_internal._StageStorageAlign(self, axis, factor, offset) + def double_buffer(self): + """Compute the current stage via double buffering. + + This can only be applied to intermediate stage. + This will double the storage cost of the current stage. + Can be useful to hide load latency. + """ + _api_internal._StageDoubleBuffer(self) + _init_api("tvm.schedule") diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index dda1bc5f56c5..50531d73010f 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -385,13 +385,18 @@ TVM_REGISTER_API("_StagePragma") TVM_REGISTER_API("_StagePrefetch") .set_body([](TVMArgs args, TVMRetValue *ret) { args[0].operator Stage() - .prefetch(args[1], args[2], args[3]); + .prefetch(args[1], args[2], args[3]); }); TVM_REGISTER_API("_StageStorageAlign") .set_body([](TVMArgs args, TVMRetValue *ret) { args[0].operator Stage() - .storage_align(args[1], args[2], args[3]); + .storage_align(args[1], args[2], args[3]); + }); + +TVM_REGISTER_API("_StageDoubleBuffer") + .set_body([](TVMArgs args, TVMRetValue *ret) { + args[0].operator Stage().double_buffer(); }); TVM_REGISTER_API("_ScheduleNormalize") diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index e5505b68df64..644ace9d9855 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -101,6 +101,7 @@ REGISTER_PASS1(CoProcSync); REGISTER_PASS1(LowerStorageAccessInfo); REGISTER_PASS1(InjectVirtualThread); REGISTER_PASS1(InjectPrefetch); +REGISTER_PASS2(InjectDoubleBuffer); REGISTER_PASS1(LoopPartition); REGISTER_PASS1(RemoveNoOp); REGISTER_PASS2(SplitPipeline); diff --git a/src/pass/inject_double_buffer.cc b/src/pass/inject_double_buffer.cc new file mode 100644 index 000000000000..f8ae5c0301e2 --- /dev/null +++ b/src/pass/inject_double_buffer.cc @@ -0,0 +1,226 @@ +/*! + * Copyright (c) 2017 by Contributors + * + * \brief Inject double buffering optimization for data fetch. + * \file inject_double_buffer.cc + */ +#include +#include +#include +#include "./ir_util.h" +#include "../arithmetic/compute_expr.h" + +namespace tvm { +namespace ir { + +// Detect double buffer variables. 
+class DoubleBufferDetector : public IRVisitor { + public: + void Visit_(const AttrStmt* op) final { + if (op->attr_key == attr::double_buffer_scope) { + touched_.insert(op->node.as<Variable>()); + IRVisitor::Visit_(op); + } else { + IRVisitor::Visit_(op); + } + } + + void Visit_(const Variable* op) final { + if (touched_.count(op)) { + touched_.erase(op); + } + } + // The set of touched variables. + std::unordered_set<const Variable*> touched_; +}; + +class DoubleBufferInjector : public IRMutator { + public: + explicit DoubleBufferInjector(bool split_loop) + : split_loop_(split_loop) {} + + Stmt Inject(const Stmt& stmt) { + DoubleBufferDetector detector; + detector.Visit(stmt); + if (detector.touched_.empty()) return stmt; + for (const Variable* v : detector.touched_) { + dbuffer_info_[v] = StorageEntry(); + } + return ConvertSSA(this->Mutate(stmt)); + } + + Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { + if (op->attr_key == attr::storage_scope) { + const Variable* buf = op->node.as<Variable>(); + auto it = dbuffer_info_.find(buf); + if (it != dbuffer_info_.end()) { + it->second.scope = op->value.as<StringImm>()->value; + return Mutate(op->body); + } else { + return IRMutator::Mutate_(op, s); + } + } else if (op->attr_key == attr::double_buffer_scope) { + return MakeProducer(op, s); + } else { + return IRMutator::Mutate_(op, s); + } + } + + Stmt Mutate_(const Allocate* op, const Stmt& s) final { + auto it = dbuffer_info_.find(op->buffer_var.get()); + if (it != dbuffer_info_.end()) { + it->second.size = arith::ComputeReduce<Mul>(op->extents); + Stmt stmt = IRMutator::Mutate_(op, s); + op = stmt.as<Allocate>(); + Array<Expr> new_extents{make_const(op->extents[0].type(), 2)}; + for (Expr e : op->extents) { + new_extents.push_back(e); + } + CHECK(it->second.loop != nullptr); + auto& alloc_nest = loop_allocs_[it->second.loop]; + alloc_nest.emplace_back(AttrStmt::make( + op->buffer_var, attr::storage_scope, + StringImm::make(it->second.scope), + Evaluate::make(0))); + alloc_nest.emplace_back(Allocate::make( + op->buffer_var, op->type, new_extents, op->condition, + Evaluate::make(0))); + return op->body; + } else { + return IRMutator::Mutate_(op, s); + } + } + + Stmt Mutate_(const For* op, const Stmt& s) final { + loop_nest_.push_back(op); + Stmt stmt = IRMutator::Mutate_(op, s); + auto it = loop_pre_.find(op); + if (it != loop_pre_.end()) { + const For* old_loop = stmt.as<For>(); + if (split_loop_) { + Expr new_ext = arith::ComputeExpr<Sub>( + old_loop->extent, make_const(old_loop->loop_var.type(), 1)); + Stmt loop = For::make( + old_loop->loop_var, old_loop->min, new_ext, + old_loop->for_type, old_loop->device_api, + old_loop->body); + std::unordered_map<const Variable*, Expr> vmap; + vmap[old_loop->loop_var.get()] = new_ext; + Stmt end = Substitute(old_loop->body, vmap); + stmt = Block::make(loop, end); + } + stmt = Block::make(MergeSeq(it->second), stmt); + } + it = loop_allocs_.find(op); + if (it != loop_allocs_.end()) { + stmt = MergeNest(it->second, stmt); + } + loop_nest_.pop_back(); + return stmt; + } + + Stmt Mutate_(const Store* op, const Stmt& s) final { + Stmt stmt = IRMutator::Mutate_(op, s); + op = stmt.as<Store>(); + auto it = dbuffer_info_.find(op->buffer_var.get()); + if (it != dbuffer_info_.end()) { + const StorageEntry& e = it->second; + CHECK(in_double_buffer_scope_); + CHECK(e.size.defined()); + return Store::make(op->buffer_var, + op->value, + e.switch_write_var * e.size + op->index, + op->predicate); + } else { + return stmt; + } + } + + Expr Mutate_(const Load* op, const Expr& e) final { + Expr expr = IRMutator::Mutate_(op, e); + op = expr.as<Load>(); + auto it =
dbuffer_info_.find(op->buffer_var.get()); + if (it != dbuffer_info_.end()) { + const StorageEntry& e = it->second; + CHECK(e.size.defined()); + CHECK(e.switch_read_var.defined()); + return Load::make(op->type, + op->buffer_var, + e.switch_read_var * e.size + op->index, + op->predicate); + } else { + return expr; + } + } + + Expr Mutate_(const Variable* op, const Expr& e) final { + CHECK(!dbuffer_info_.count(op)); + return e; + } + + private: + Stmt MakeProducer(const AttrStmt* op, const Stmt& s) { + const VarExpr buffer(op->node.node_); + CHECK_NE(loop_nest_.size(), 0U) + << "Double buffer scope must be inside a loop"; + auto it = dbuffer_info_.find(buffer.get()); + if (it == dbuffer_info_.end()) { + LOG(WARNING) << "Skip double buffer scope " << op->node; + return Mutate(op->body); + } + StorageEntry& e = it->second; + e.loop = loop_nest_.back(); + Expr zero = make_const(e.loop->loop_var.type(), 0); + Expr one = make_const(e.loop->loop_var.type(), 1); + Expr two = make_const(e.loop->loop_var.type(), 2); + Expr loop_shift = e.loop->loop_var + one; + e.switch_write_var = Var(e.loop->loop_var->name_hint + ".db", + e.loop->loop_var.type()); + e.switch_read_var = e.loop->loop_var % two; + in_double_buffer_scope_ = true; + Stmt body = Mutate(op->body); + in_double_buffer_scope_ = false; + std::unordered_map<const Variable*, Expr> vmap; + vmap[e.switch_write_var.get()] = zero; + vmap[e.loop->loop_var.get()] = zero; + loop_pre_[e.loop].emplace_back(Substitute(body, vmap)); + vmap[e.loop->loop_var.get()] = loop_shift; + vmap[e.switch_write_var.get()] = loop_shift % two; + body = Substitute(body, vmap); + body = AttrStmt::make(buffer, attr::double_buffer_write, 1, body); + body = IfThenElse::make(loop_shift < e.loop->extent, body); + return body; + } + // Storage entry for a buffer that needs double buffering. + struct StorageEntry { + // The size of the buffer + Expr size; + // The loop in which the buffer is double buffered + const For* loop{nullptr}; + // The switch variable for writing. + VarExpr switch_write_var; + // The switch variable for reading. + Expr switch_read_var; + // The storage scope. + std::string scope; + }; + // Whether to split the loop + bool split_loop_; + // Whether we are inside double buffer scope.
+ bool in_double_buffer_scope_{false}; + // The current loop nest + std::vector<const For*> loop_nest_; + // The allocs to be appended before the loop + std::unordered_map<const For*, std::vector<Stmt> > loop_allocs_; + // The stmt to be appended before the loop + std::unordered_map<const For*, std::vector<Stmt> > loop_pre_; + // The storage entry of each double buffered variable + std::unordered_map<const Variable*, StorageEntry> dbuffer_info_; +}; + + +Stmt InjectDoubleBuffer(Stmt stmt, bool split_loop) { + return DoubleBufferInjector(split_loop).Inject(stmt); +} +} // namespace ir +} // namespace tvm diff --git a/src/pass/storage_access.cc b/src/pass/storage_access.cc index 399d92133f74..9211f3f71de0 100644 --- a/src/pass/storage_access.cc +++ b/src/pass/storage_access.cc @@ -74,6 +74,24 @@ void StorageAccessVisitor::Visit_(const AttrStmt* op) { storage_scope_[buf] = StorageScope::make(op->value.as<StringImm>()->value); IRVisitor::Visit_(op); + } else if (op->attr_key == attr::double_buffer_write) { + CHECK(double_buffer_write_ == nullptr); + double_buffer_write_ = op->node.as<Variable>(); + scope_.push_back(std::vector<StmtEntry>()); + IRVisitor::Visit_(op); + StmtEntry s; + s.stmt = op; + s.access = Summarize(std::move(scope_.back()), nullptr); + scope_.pop_back(); + if (!s.access.empty()) { + for (AccessEntry& e : s.access) { + if (e.type == kWrite && e.buffer.get() == double_buffer_write_) { + e.double_buffer_write = true; + } + } + scope_.back().emplace_back(std::move(s)); + } + double_buffer_write_ = nullptr; } else if (op->attr_key == attr::coproc_scope) { IterVar iv(op->node.node_); env_threads_.push_back(iv); diff --git a/src/pass/storage_access.h b/src/pass/storage_access.h index 9e40e75223a5..7268bb668342 100644 --- a/src/pass/storage_access.h +++ b/src/pass/storage_access.h @@ -45,6 +45,8 @@ class StorageAccessVisitor : public IRVisitor { AccessType type; /*! \brief The storage scope */ StorageScope scope; + /*! \brief Whether the access is a double buffer write */ + bool double_buffer_write{false}; }; /*! \brief Access pattern about a single statement */ struct StmtEntry { @@ -116,6 +118,8 @@ class StorageAccessVisitor : public IRVisitor { bool in_device_env_{false}; // Whether we are inside condition. int condition_counter_{0}; + // The current double buffer write scope. + const Variable* double_buffer_write_{nullptr}; // the current free stmt entry.
StmtEntry curr_stmt_; // The involving threads diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc index 551442486e95..30b09a6da520 100644 --- a/src/pass/storage_flatten.cc +++ b/src/pass/storage_flatten.cc @@ -4,6 +4,7 @@ */ #include #include +#include #include #include #include @@ -53,6 +54,18 @@ class StorageFlattener : public IRMutator { if (op->attr_key == attr::realize_scope) { storage_scope_[op->node.get()] = op->value.as()->value; return this->Mutate(op->body); + } else if (op->attr_key == attr::double_buffer_scope) { + Operation func(op->node.node_); + Stmt body = Mutate(op->body); + for (int i = 0; i < func->num_outputs(); ++i) { + TensorKey key{func, i}; + auto it = buf_map_.find(key); + CHECK(it != buf_map_.end()) + << "Cannot find allocated buffer for " << key.f; + body = AttrStmt::make( + it->second.buffer->data, op->attr_key, op->value, body); + } + return body; } else if (op->attr_key == attr::thread_extent) { IterVar iv(op->node.node_); ThreadScope ts = ThreadScope::make(iv->thread_tag); diff --git a/src/pass/storage_sync.cc b/src/pass/storage_sync.cc index 13773321edac..af3dc1f128e5 100644 --- a/src/pass/storage_sync.cc +++ b/src/pass/storage_sync.cc @@ -34,13 +34,10 @@ class ThreadSyncPlanner : public StorageAccessVisitor { // Unsynced reads and writes std::vector reads; std::vector writes; - // if it is a loop, rotate two times to consider effect of loop. - size_t max_seq = seq.size(); - if (loop != nullptr) max_seq *= 2; // simulation based approach to find dependenceies - for (size_t i = 0; i < max_seq; ++i) { - const StmtEntry& s = seq[i % seq.size()]; + for (size_t i = 0; i < seq.size(); ++i) { + const StmtEntry& s = seq[i]; // check if sync before statement is needed. bool sync_before_stmt = (syncs_inserted_.count(s.stmt) != 0); // Apply the syncs added already. @@ -50,11 +47,11 @@ class ThreadSyncPlanner : public StorageAccessVisitor { } for (const AccessEntry& acc : s.access) { if (acc.type == kRead) { - if (FindConflict(writes, acc)) { + if (FindConflict(writes, acc, false)) { sync_before_stmt = true; break; } } else if (acc.type == kWrite) { - if (FindConflict(reads, acc)) { + if (FindConflict(reads, acc, false)) { sync_before_stmt = true; break; } } else if (acc.type == kSync) { @@ -81,6 +78,33 @@ class ThreadSyncPlanner : public StorageAccessVisitor { syncs_inserted_.insert(s.stmt); } } + if (loop != nullptr) { + for (size_t i = 0; i < seq.size(); ++i) { + const StmtEntry& s = seq[i]; + if (syncs_inserted_.count(s.stmt) != 0) break; + if (reads.empty() && writes.empty()) break; + bool sync_before_stmt = false; + for (const AccessEntry& acc : s.access) { + if (acc.type == kRead) { + if (FindConflict(writes, acc, true)) { + sync_before_stmt = true; break; + } + } else if (acc.type == kWrite) { + if (FindConflict(reads, acc, true)) { + sync_before_stmt = true; break; + } + } else if (acc.type == kSync) { + reads.clear(); writes.clear(); + } + } + if (sync_before_stmt) { + CHECK_EQ(condition_counter(), 0) + << "Cannot insert syncs inside condition"; + syncs_inserted_.insert(s.stmt); + break; + } + } + } // return the exposed entries, remove unecessary ones. int sync_count = 0; // head are before first sync, tail are after last sync @@ -117,13 +141,20 @@ class ThreadSyncPlanner : public StorageAccessVisitor { } } head.insert(head.end(), tail.begin(), tail.end()); + if (loop != nullptr) { + // clear double buffer flag after a loop is finished. 
+ for (AccessEntry& e : head) { + e.double_buffer_write = false; + } + } return head; } private: // find conflicting entry in vec. bool FindConflict(const std::vector& vec, - const AccessEntry& e) { + const AccessEntry& e, + bool loop_carry) { for (const AccessEntry& x : vec) { if (x.buffer.same_as(e.buffer)) { // Assumes no race between threads @@ -134,6 +165,9 @@ class ThreadSyncPlanner : public StorageAccessVisitor { if (Equal(e.touched.point_value(), x.touched.point_value())) continue; } + if (x.double_buffer_write && + e.type == kRead && + !loop_carry) continue; return true; } } diff --git a/src/schedule/schedule_lang.cc b/src/schedule/schedule_lang.cc index c07f60dccceb..7f979e608cdf 100644 --- a/src/schedule/schedule_lang.cc +++ b/src/schedule/schedule_lang.cc @@ -385,6 +385,13 @@ Stage& Stage::storage_align(IterVar axis, int factor, int offset) { return *this; } +Stage& Stage::double_buffer() { + StageNode *self = operator->(); + CHECK(!self->is_output) << "Cannot apply double buffer on output"; + self->double_buffer = true; + return *this; +} + Stage CopyStage(const Stage& s) { std::shared_ptr n = std::make_shared(*s.operator->()); diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc index 724672d2a951..875df556466a 100644 --- a/src/schedule/schedule_ops.cc +++ b/src/schedule/schedule_ops.cc @@ -27,6 +27,10 @@ Stmt MakePipeline(const Stage& s, if (producer.defined()) { producer = ProducerConsumer::make(s->op, true, producer); } + if (s->double_buffer) { + producer = AttrStmt::make( + s->op, ir::attr::double_buffer_scope, 1, producer); + } Stmt pipeline = producer; if (consumer.defined() && !is_no_op(consumer)) { @@ -170,7 +174,8 @@ class SchedulePostProc : public IRMutator { thread_extent_scope_.erase(op->node.get()); return ret; } - } else if (op->attr_key == ir::attr::realize_scope) { + } else if (op->attr_key == ir::attr::realize_scope || + op->attr_key == ir::attr::double_buffer_scope) { auto it = replace_op_.find(op->node.get()); if (it != replace_op_.end()) { if (it->second.defined()) { diff --git a/tests/python/integration/test_gemm.py b/tests/python/integration/test_gemm.py index 3c190f09e7e3..0798ecf61e4f 100644 --- a/tests/python/integration/test_gemm.py +++ b/tests/python/integration/test_gemm.py @@ -47,7 +47,8 @@ def test_gemm(): s[CC].compute_at(s[C], tx) s[AA].compute_at(s[CC], k) s[BB].compute_at(s[CC], k) - + s[AA].double_buffer() + s[BB].double_buffer() ty, xi = s[AA].split(s[AA].op.axis[0], nparts=num_thread) tx, xi = s[AA].split(xi, nparts=num_thread) s[AA].bind(ty, thread_y) @@ -84,10 +85,10 @@ def check_device(device): np.testing.assert_allclose( c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5) - check_device("nvptx -mcpu=sm_20") check_device("metal") check_device("opencl") check_device("cuda") + #check_device("nvptx -mcpu=sm_20") if __name__ == "__main__": test_gemm() diff --git a/tests/python/unittest/test_pass_inject_double_buffer.py b/tests/python/unittest/test_pass_inject_double_buffer.py new file mode 100644 index 000000000000..133ba7f7e17e --- /dev/null +++ b/tests/python/unittest/test_pass_inject_double_buffer.py @@ -0,0 +1,37 @@ +import tvm + +def test_double_buffer(): + dtype = 'int64' + n = 100 + m = 4 + tx = tvm.thread_axis("threadIdx.x") + ib = tvm.ir_builder.create() + A = ib.pointer("float32", name="A") + C = ib.pointer("float32", name="A") + ib.scope_attr(tx, "thread_extent", 1) + with ib.for_range(0, n) as i: + B = ib.allocate("float32", m, name="B", scope="shared") + with ib.new_scope(): + ib.scope_attr(B.asnode(), 
"double_buffer_scope", 1) + with ib.for_range(0, m) as j: + B[j] = A[i * 4 + j] + with ib.for_range(0, m) as j: + C[j] = B[j] + 1 + + stmt = ib.get() + stmt = tvm.ir_pass.InjectDoubleBuffer(stmt, True) + stmt = tvm.ir_pass.Simplify(stmt) + assert isinstance(stmt.body.body, tvm.stmt.Allocate) + assert stmt.body.body.extents[0].value == 2 + f = tvm.ir_pass.MakeAPI(stmt, "db", [A.asnode(), C.asnode()], 2, True) + f = tvm.ir_pass.ThreadSync(f, "shared") + count = [0] + def count_sync(op): + if isinstance(op, tvm.expr.Call) and op.name == "tvm_storage_sync": + count[0] += 1 + tvm.ir_pass.PostOrderVisit(f.body, count_sync) + assert count[0] == 2 + + +if __name__ == "__main__": + test_double_buffer() diff --git a/topi/recipe/gemm/cuda_gemm_square.py b/topi/recipe/gemm/cuda_gemm_square.py index a8417a1e3426..f27d6a74d883 100644 --- a/topi/recipe/gemm/cuda_gemm_square.py +++ b/topi/recipe/gemm/cuda_gemm_square.py @@ -96,6 +96,8 @@ def test_gemm(): s[BB].bind(ty, thread_y) s[BB].bind(tx, thread_x) s[BB].vectorize(xi) + s[AA].double_buffer() + s[BB].double_buffer() # correctness def check_device(device): if not tvm.module.enabled(device):