Skip to content

Commit

Permalink
[tir] Add line level debug info (#13012)
Browse files Browse the repository at this point in the history
* TIR debug info

* Fix location emission

* Comments 1/N (docs, cleanups)

* Remove leaky macro usage

* Add unit test

* Remove dead code

* Add accuracy test

Co-authored-by: driazati <driazati@users.noreply.github.com>
  • Loading branch information
driazati and driazati authored Jan 6, 2023
1 parent 21d7968 commit 123f1f5
Show file tree
Hide file tree
Showing 16 changed files with 762 additions and 89 deletions.
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -271,3 +271,9 @@ gallery/how_to/work_with_microtvm/micro_tvmc.py

# Used in CI to communicate between Python and Jenkins
.docker-image-names/

# Printed TIR code on disk
*.tir

# GDB history file
.gdb_history
7 changes: 7 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,13 @@ TVM_DLL Pass LowerAsyncDMA();
*/
TVM_DLL Pass CommonSubexprElimTIR(bool enable_cse_tir = true, bool identify_equiv_terms = false);

/*!
* \brief Add TIR-printer output as debug information to all ops in the module
* \return The pass.
*/

TVM_DLL Pass InstallDebugSpans();

/*!
* \brief Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and
* "vthread.x/y/z". Before the unification, two vars that are bound to a thread axis (e.g.,
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,3 +1028,15 @@ def InstrumentProfileIntrinsics():
The result pass
"""
return _ffi_api.InstrumentProfileIntrinsics() # type: ignore


def InstallDebugSpans():
"""Add line information from the TIR printer as spans on each statement and
expression.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.InstallDebugSpans() # type: ignore
8 changes: 8 additions & 0 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_debug", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
Expand Down Expand Up @@ -603,6 +604,9 @@ TVM_REGISTER_GLOBAL("driver.mixed_mod_passes")
});

transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host) {
transform::PassContext pass_ctx = transform::PassContext::Current();
bool enable_debug = pass_ctx->GetConfig<Bool>("tir.enable_debug", Bool(false)).value();

Array<tvm::transform::Pass> host_pass_list;

runtime::TypedPackedFunc<bool(tir::PrimFunc)> fcond = [](const tir::PrimFunc& f) {
Expand All @@ -621,6 +625,10 @@ transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_ho
host_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo());
host_pass_list.push_back(tir::transform::CombineContextCall());

if (enable_debug) {
host_pass_list.push_back(tir::transform::InstallDebugSpans());
}

return transform::Sequential(host_pass_list);
}

Expand Down
1 change: 1 addition & 0 deletions src/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,7 @@ Pass GetPass(const String& pass_name) {
// ordering problem needs to be handled in the future.
IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) const {
for (const Pass& pass : passes) {
VLOG(0) << "Running pass " << pass->Info()->name;
ICHECK(pass.defined()) << "Found undefined pass for optimization.";
const PassInfo& pass_info = pass->Info();
if (!pass_ctx.PassEnabled(pass_info)) {
Expand Down
40 changes: 22 additions & 18 deletions src/printer/text_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,9 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
explicit TIRTextPrinter(bool show_meta, TextMetaDataContext* meta)
: show_meta_(show_meta), meta_(meta), meta_collector_(meta) {}

/*! \brief Output a newline */
virtual Doc NewLine();

/*! \brief Print the node */
Doc Print(const ObjectRef& node);

Expand All @@ -290,24 +293,7 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
*/
bool GetVarName(::tvm::tir::Var v, std::string* s);

private:
/*! \brief whether show meta data */
bool show_meta_;
/*! \brief meta data context */
TextMetaDataContext* meta_;
/*! \brief meta collector */
MetaCollector meta_collector_;
/*! \brief Map from Var to Doc */
std::unordered_map<Var, Doc, ObjectPtrHash, ObjectPtrEqual> memo_var_;
/*! \brief Map from Buffer to Doc */
std::unordered_map<Buffer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_buf_;
/*! \brief Map from Buffer to Doc */
std::unordered_map<DataProducer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_producer_;
/*! \brief name allocation map */
std::unordered_map<std::string, int> name_alloc_map_;

friend class tvm::TextPrinter;

protected:
Doc VisitExpr_(const IntImmNode* op) override;
Doc VisitExpr_(const FloatImmNode* op) override;
Doc VisitExpr_(const StringImmNode* op) override;
Expand Down Expand Up @@ -363,6 +349,24 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc VisitStmt_(const BlockRealizeNode* op) override;
Doc VisitStmtDefault_(const Object* op) override;

private:
/*! \brief whether show meta data */
bool show_meta_;
/*! \brief meta data context */
TextMetaDataContext* meta_;
/*! \brief meta collector */
MetaCollector meta_collector_;
/*! \brief Map from Var to Doc */
std::unordered_map<Var, Doc, ObjectPtrHash, ObjectPtrEqual> memo_var_;
/*! \brief Map from Buffer to Doc */
std::unordered_map<Buffer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_buf_;
/*! \brief Map from Buffer to Doc */
std::unordered_map<DataProducer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_producer_;
/*! \brief name allocation map */
std::unordered_map<std::string, int> name_alloc_map_;

friend class tvm::TextPrinter;

Doc VisitType_(const PrimTypeNode* node) override;
Doc VisitType_(const PointerTypeNode* node) override;
Doc VisitType_(const TupleTypeNode* node) override;
Expand Down
53 changes: 27 additions & 26 deletions src/printer/tir_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) {
for (const auto& it : op->attrs->dict) {
attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second));
}
attr_doc << Doc::NewLine() << "attr = {" << PrintSep(attr_docs, Doc::Text(", ")) << "}";
attr_doc << NewLine() << "attr = {" << PrintSep(attr_docs, Doc::Text(", ")) << "}";
doc << Doc::Indent(2, attr_doc);
}

Expand All @@ -136,8 +136,8 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) {
const Buffer buf = op->buffer_map[v];
buffer_docs.push_back(BufferNode2Doc(buf.get(), Print(buf)));
}
buffer_doc << Doc::NewLine() << "buffers = {";
buffer_doc << PrintSep(buffer_docs, Doc::Indent(11, Doc::Text(",") << Doc::NewLine()));
buffer_doc << NewLine() << "buffers = {";
buffer_doc << PrintSep(buffer_docs, Doc::Indent(11, Doc::Text(",") << NewLine()));
doc << Doc::Indent(2, buffer_doc) << "}";
}

Expand All @@ -149,26 +149,28 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) {
buffer_map_doc.push_back(Print(v) << ": " << Print(buf));
}
doc << Doc::Indent(
2, Doc::NewLine() << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}");
2, NewLine() << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}");
}

doc << PrintBody(op->body);
return doc;
}

Doc TIRTextPrinter::NewLine() { return Doc::NewLine(); }

Doc TIRTextPrinter::PrintIRModule(const IRModule& module) {
const auto* op = module.operator->();
Doc doc;

Doc body;
body << Doc::NewLine();
body << NewLine();
std::vector<Doc> functions;
for (auto it = op->functions.begin(); it != op->functions.end(); ++it) {
if ((*it).second.as<PrimFuncNode>()) {
functions.push_back(Print((*it).second));
}
}
body << TIRTextPrinter::PrintSep(functions, Doc::NewLine() << Doc::NewLine());
body << TIRTextPrinter::PrintSep(functions, NewLine() << NewLine());
doc << Doc::Indent(0, body);
return doc;
}
Expand Down Expand Up @@ -451,7 +453,7 @@ Doc TIRTextPrinter::VisitExpr_(const ReduceNode* op) {

Doc TIRTextPrinter::VisitStmt_(const LetStmtNode* op) {
Doc doc;
doc << "let " << Print(op->var) << " = " << Print(op->value) << Doc::NewLine() << Print(op->body);
doc << "let " << Print(op->var) << " = " << Print(op->value) << NewLine() << Print(op->body);
return doc;
}

Expand All @@ -463,14 +465,14 @@ Doc TIRTextPrinter::VisitStmt_(const AttrStmtNode* op) {
if (op->body->IsInstance<SeqStmtNode>()) {
doc << PrintBody(op->body);
} else {
doc << ";" << Doc::NewLine() << Print(op->body);
doc << ";" << NewLine() << Print(op->body);
}
return doc;
}

Doc TIRTextPrinter::VisitStmt_(const AssertStmtNode* op) {
Doc doc;
doc << "assert(" << Print(op->condition) << ", " << Print(op->message) << ")" << Doc::NewLine()
doc << "assert(" << Print(op->condition) << ", " << Print(op->message) << ")" << NewLine()
<< Print(op->body);
return doc;
}
Expand Down Expand Up @@ -529,7 +531,7 @@ Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) {
if (op->body->IsInstance<SeqStmtNode>()) {
doc << PrintBody(op->body);
} else {
doc << ";" << Doc::NewLine() << Print(op->body);
doc << ";" << NewLine() << Print(op->body);
}
return doc;
}
Expand All @@ -542,19 +544,19 @@ Doc TIRTextPrinter::VisitStmt_(const AllocateConstNode* op) {
if (op->body->IsInstance<SeqStmtNode>()) {
doc << PrintBody(op->body);
} else {
doc << ";" << Doc::NewLine() << Print(op->body);
doc << ";" << NewLine() << Print(op->body);
}
return doc;
}

Doc TIRTextPrinter::VisitStmt_(const DeclBufferNode* op) {
Doc doc;
doc << AllocBuf(op->buffer) << " = decl_buffer(" << Print(op->buffer->data) << ", "
<< PrintDType(op->buffer->dtype) << ", " << Print(op->buffer->shape) << ")" << Doc::NewLine();
<< PrintDType(op->buffer->dtype) << ", " << Print(op->buffer->shape) << ")" << NewLine();
if (op->body->IsInstance<SeqStmtNode>()) {
doc << PrintBody(op->body);
} else {
doc << ";" << Doc::NewLine() << Print(op->body);
doc << ";" << NewLine() << Print(op->body);
}
return doc;
}
Expand All @@ -572,9 +574,9 @@ Doc TIRTextPrinter::VisitStmt_(const SeqStmtNode* op) {
std::vector<Doc> stmts;
Doc seq_doc, doc;
for (Stmt stmt : op->seq) {
seq_doc << Doc::NewLine() << Print(stmt);
seq_doc << NewLine() << Print(stmt);
}
doc << " {" << Doc::Indent(2, seq_doc) << Doc::NewLine() << "}";
doc << " {" << Doc::Indent(2, seq_doc) << NewLine() << "}";
return doc;
}

Expand Down Expand Up @@ -657,37 +659,36 @@ Doc TIRTextPrinter::VisitStmt_(const BlockRealizeNode* op) {
Doc block_attr_doc;
// print predicate, binding, read/write tensor region, annotations
if (!is_one(op->predicate)) {
block_attr_doc << Doc::NewLine() << "where(" << Print(op->predicate) << ")";
block_attr_doc << NewLine() << "where(" << Print(op->predicate) << ")";
}
for (size_t i = 0; i < block_op->iter_vars.size(); ++i)
block_attr_doc << Doc::NewLine() << "bind(" << Print(block_op->iter_vars[i]->var) << ", "
block_attr_doc << NewLine() << "bind(" << Print(block_op->iter_vars[i]->var) << ", "
<< Print(op->iter_values[i]) << ")";
block_attr_doc << Doc::NewLine() << "tir.reads(" << Print(block_op->reads) << ")";
block_attr_doc << Doc::NewLine() << "tir.writes(" << Print(block_op->writes) << ")";
block_attr_doc << NewLine() << "tir.reads(" << Print(block_op->reads) << ")";
block_attr_doc << NewLine() << "tir.writes(" << Print(block_op->writes) << ")";
if (!block_op->annotations.empty()) {
std::vector<Doc> attr_docs;
for (const auto& it : block_op->annotations) {
attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second));
}
block_attr_doc << Doc::NewLine() << "tir.attrs({" << PrintSep(attr_docs, Doc::Text(", "))
<< "})";
block_attr_doc << NewLine() << "tir.attrs({" << PrintSep(attr_docs, Doc::Text(", ")) << "})";
}
// print body
Doc body;
body << Doc::NewLine();
body << NewLine();
for (const auto& alloc_buf : block_op->alloc_buffers) {
body << AllocBuf(alloc_buf) << " = alloc_buffer(" << PrintDType(alloc_buf->dtype)
<< Print(alloc_buf->shape) << ")" << Doc::NewLine();
<< Print(alloc_buf->shape) << ")" << NewLine();
}
for (const auto& match_buf : block_op->match_buffers) {
body << AllocBuf(match_buf->buffer) << " = match_buffer(" << Print(match_buf->source) << ")"
<< Doc::NewLine();
<< NewLine();
}
if (block_op->init.defined()) {
Doc init_block;
init_block << "with init()";
init_block << PrintBody(block_op->init.value());
body << init_block << Doc::NewLine();
body << init_block << NewLine();
}
body << Print(block_op->body);
doc << Doc::Indent(2, block_attr_doc << body);
Expand Down Expand Up @@ -826,7 +827,7 @@ Doc TIRTextPrinter::PrintSep(const std::vector<Doc>& vec, const Doc& sep) {
Doc TIRTextPrinter::PrintBody(const Stmt& body, bool indent) {
Doc doc;
if (body->IsInstance<SeqStmtNode>()) return Print(body);
doc << " {" << Doc::Indent(2, Doc::NewLine() << Print(body)) << Doc::NewLine() << "}";
doc << " {" << Doc::Indent(2, NewLine() << Print(body)) << NewLine() << "}";
return doc;
}

Expand Down
Loading

0 comments on commit 123f1f5

Please sign in to comment.