diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index 6504e2c2843d..6321caa4e057 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -1194,6 +1194,50 @@ class ClassDoc : public StmtDoc { TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ClassDoc, StmtDoc, ClassDocNode); }; +/*! + * \brief Doc that represents comment. + * + * \sa CommentDoc + */ +class CommentDocNode : public StmtDocNode { + public: + static constexpr const char* _type_key = "script.printer.CommentDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(CommentDocNode, StmtDocNode); +}; + +/*! + * \brief Reference type of CommentDocNode. + * + * \sa CommentDocNode + */ +class CommentDoc : public StmtDoc { + public: + explicit CommentDoc(String comment); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(CommentDoc, StmtDoc, CommentDocNode); +}; + +/*! + * \brief Doc that represents docstring. + * + * \sa DocStringDoc + */ +class DocStringDocNode : public StmtDocNode { + public: + static constexpr const char* _type_key = "script.printer.DocStringDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(DocStringDocNode, StmtDocNode); +}; + +/*! + * \brief Reference type of DocStringDocNode. + * + * \sa DocStringDocNode + */ +class DocStringDoc : public StmtDoc { + public: + explicit DocStringDoc(String docs); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(DocStringDoc, StmtDoc, DocStringDocNode); +}; + } // namespace printer } // namespace script } // namespace tvm diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h index 67fa96ef8082..c41827fe9530 100644 --- a/include/tvm/script/printer/ir_docsifier.h +++ b/include/tvm/script/printer/ir_docsifier.h @@ -24,6 +24,7 @@ #include #include +#include #include #include #include @@ -148,6 +149,8 @@ class IRDocsifierNode : public Object { std::unordered_set defined_names; /*! \brief Common prefixes of variable usages */ std::unordered_map> common_prefix; + /*! \brief The IR usages for headers printing */ + std::unordered_set ir_usage; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("frames", &frames); @@ -156,6 +159,7 @@ class IRDocsifierNode : public Object { // `obj2info` is not visited // `defined_names` is not visited // `common_prefix` is not visited + // `ir_usage` is not visited } static constexpr const char* _type_key = "script.printer.IRDocsifier"; diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py index 5a4a4cd67a72..9a6e7f1b8c8f 100644 --- a/python/tvm/script/printer/doc.py +++ b/python/tvm/script/printer/doc.py @@ -521,3 +521,23 @@ def __init__(self, name: IdDoc, decorators: List[ExprDoc], body: List[StmtDoc]): decorators, body, ) + + +@register_object("script.printer.CommentDoc") +class CommentDoc(StmtDoc): + """Doc that represents comment.""" + + def __init__(self, comment: str): + self.__init_handle_by_constructor__( + _ffi_api.CommentDoc, comment # type: ignore # pylint: disable=no-member + ) + + +@register_object("script.printer.DocStringDoc") +class DocStringDoc(StmtDoc): + """Doc that represents docstring.""" + + def __init__(self, docs: str): + self.__init_handle_by_constructor__( + _ffi_api.DocStringDoc, docs # type: ignore # pylint: disable=no-member + ) diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc index 89f6b7c8b1cf..1db4e090dcff 100644 --- a/src/script/printer/doc.cc +++ b/src/script/printer/doc.cc @@ -221,6 +221,18 @@ ClassDoc::ClassDoc(IdDoc name, Array decorators, Array body) { this->data_ = std::move(n); } +CommentDoc::CommentDoc(String comment) { + ObjectPtr n = make_object(); + n->comment = comment; + this->data_ = std::move(n); +} + +DocStringDoc::DocStringDoc(String docs) { + ObjectPtr n = make_object(); + n->comment = docs; + this->data_ = std::move(n); +} + TVM_REGISTER_NODE_TYPE(DocNode); TVM_REGISTER_GLOBAL("script.printer.DocSetSourcePaths") .set_body_typed([](Doc doc, Array source_paths) { @@ -365,6 +377,16 @@ TVM_REGISTER_GLOBAL("script.printer.ClassDoc") return ClassDoc(name, decorators, body); }); +TVM_REGISTER_NODE_TYPE(CommentDocNode); +TVM_REGISTER_GLOBAL("script.printer.CommentDoc").set_body_typed([](String comment) { + return CommentDoc(comment); +}); + +TVM_REGISTER_NODE_TYPE(DocStringDocNode); +TVM_REGISTER_GLOBAL("script.printer.DocStringDoc").set_body_typed([](String docs) { + return DocStringDoc(docs); +}); + } // namespace printer } // namespace script } // namespace tvm diff --git a/src/script/printer/doc_printer/base_doc_printer.cc b/src/script/printer/doc_printer/base_doc_printer.cc index a3a5c06ede0d..8df599347f07 100644 --- a/src/script/printer/doc_printer/base_doc_printer.cc +++ b/src/script/printer/doc_printer/base_doc_printer.cc @@ -316,6 +316,10 @@ void DocPrinter::PrintDoc(const Doc& doc) { PrintTypedDoc(GetRef(doc_node)); } else if (const auto* doc_node = doc.as()) { PrintTypedDoc(GetRef(doc_node)); + } else if (const auto* doc_node = doc.as()) { + PrintTypedDoc(GetRef(doc_node)); + } else if (const auto* doc_node = doc.as()) { + PrintTypedDoc(GetRef(doc_node)); } else { LOG(FATAL) << "Do not know how to print " << doc->GetTypeKey(); throw; diff --git a/src/script/printer/doc_printer/base_doc_printer.h b/src/script/printer/doc_printer/base_doc_printer.h index 7851ce061b0d..f5cf40a23357 100644 --- a/src/script/printer/doc_printer/base_doc_printer.h +++ b/src/script/printer/doc_printer/base_doc_printer.h @@ -204,6 +204,16 @@ class DocPrinter { */ virtual void PrintTypedDoc(const ClassDoc& doc) = 0; + /*! + * \brief Virtual method to print a CommentDoc + */ + virtual void PrintTypedDoc(const CommentDoc& doc) = 0; + + /*! + * \brief Virtual method to print a DocStringDoc + */ + virtual void PrintTypedDoc(const DocStringDoc& doc) = 0; + /*! * \brief Increase the indent level of any content to be * printed after this call diff --git a/src/script/printer/doc_printer/python_doc_printer.cc b/src/script/printer/doc_printer/python_doc_printer.cc index ce6b8e7f423c..334f76f72280 100644 --- a/src/script/printer/doc_printer/python_doc_printer.cc +++ b/src/script/printer/doc_printer/python_doc_printer.cc @@ -169,6 +169,8 @@ class PythonDocPrinter : public DocPrinter { void PrintTypedDoc(const ScopeDoc& doc) final; void PrintTypedDoc(const FunctionDoc& doc) final; void PrintTypedDoc(const ClassDoc& doc) final; + void PrintTypedDoc(const CommentDoc& doc) final; + void PrintTypedDoc(const DocStringDoc& doc) final; private: void NewLineWithoutIndent() { output_ << "\n"; } @@ -253,11 +255,19 @@ class PythonDocPrinter : public DocPrinter { } } - void MaybePrintCommentWithNewLine(const StmtDoc& stmt) { + void MaybePrintCommenMultiLines(const StmtDoc& stmt, bool new_line = false) { if (stmt->comment.defined()) { std::vector comment_lines = support::Split(stmt->comment.value(), '\n'); + bool first_line = true; for (const std::string& line : comment_lines) { - output_ << "# " << line; + if (first_line) { + output_ << "# " << line; + first_line = false; + } else { + NewLine() << "# " << line; + } + } + if (new_line) { NewLine(); } } @@ -523,7 +533,7 @@ void PythonDocPrinter::PrintTypedDoc(const AssignDoc& doc) { } void PythonDocPrinter::PrintTypedDoc(const IfDoc& doc) { - MaybePrintCommentWithNewLine(doc); + MaybePrintCommenMultiLines(doc, true); output_ << "if "; PrintDoc(doc->predicate); output_ << ":"; @@ -538,7 +548,7 @@ void PythonDocPrinter::PrintTypedDoc(const IfDoc& doc) { } void PythonDocPrinter::PrintTypedDoc(const WhileDoc& doc) { - MaybePrintCommentWithNewLine(doc); + MaybePrintCommenMultiLines(doc, true); output_ << "while "; PrintDoc(doc->predicate); output_ << ":"; @@ -547,7 +557,7 @@ void PythonDocPrinter::PrintTypedDoc(const WhileDoc& doc) { } void PythonDocPrinter::PrintTypedDoc(const ForDoc& doc) { - MaybePrintCommentWithNewLine(doc); + MaybePrintCommenMultiLines(doc, true); output_ << "for "; if (const auto* tuple = doc->lhs.as()) { if (tuple->elements.size() == 1) { @@ -567,7 +577,7 @@ void PythonDocPrinter::PrintTypedDoc(const ForDoc& doc) { } void PythonDocPrinter::PrintTypedDoc(const ScopeDoc& doc) { - MaybePrintCommentWithNewLine(doc); + MaybePrintCommenMultiLines(doc, true); output_ << "with "; PrintDoc(doc->rhs); if (doc->lhs != nullptr) { @@ -642,6 +652,18 @@ void PythonDocPrinter::PrintTypedDoc(const ClassDoc& doc) { NewLineWithoutIndent(); } +void PythonDocPrinter::PrintTypedDoc(const CommentDoc& doc) { + if (doc->comment.defined()) { + MaybePrintCommenMultiLines(doc, false); + } +} + +void PythonDocPrinter::PrintTypedDoc(const DocStringDoc& doc) { + if (doc->comment.defined() && !doc->comment.value().empty()) { + output_ << "\"\"\"" << doc->comment.value() << "\"\"\""; + } +} + String DocToPythonScript(Doc doc, const PrinterConfig& cfg) { if (cfg->num_context_lines < 0) { cfg->num_context_lines = std::numeric_limits::max(); diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc index 4a246e169276..7f7857dba671 100644 --- a/src/script/printer/ir/ir.cc +++ b/src/script/printer/ir/ir.cc @@ -119,7 +119,8 @@ std::string ReprPrintIRModule(const ObjectRef& mod, const PrinterConfig& cfg) { return s.value(); } } - Doc doc = IRDocsifier(cfg)->AsDoc(mod, ObjectPath::Root()); + IRDocsifier d(cfg); + Doc doc = HeaderWrapper(d, d->AsDoc(mod, ObjectPath::Root())); return DocToPythonScript(doc, cfg); } diff --git a/src/script/printer/ir/utils.h b/src/script/printer/ir/utils.h index d20756e6081a..a05030516f3f 100644 --- a/src/script/printer/ir/utils.h +++ b/src/script/printer/ir/utils.h @@ -36,6 +36,7 @@ namespace printer { /*! \brief Creates the IR common prefix, which is by default `I` */ inline ExprDoc IR(const IRDocsifier& d, const String& attr) { + d->ir_usage.insert("ir"); return IdDoc(d->cfg->ir_prefix)->Attr(attr); } diff --git a/src/script/printer/tir/function.cc b/src/script/printer/tir/function.cc index fbcc2fca3b4b..65f3db5b4fec 100644 --- a/src/script/printer/tir/function.cc +++ b/src/script/printer/tir/function.cc @@ -153,6 +153,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (implicit_root_block) { tir::Block root_block = implicit_root_block.value(); ObjectPath root_block_p = p->Attr("body")->Attr("body"); + (*frame)->stmts.push_back(CommentDoc("with T.block(\"root\"):")); // Handle root block `alloc_buffer` for (int i = 0, n = root_block->alloc_buffers.size(); i < n; ++i) { tir::Buffer buffer = root_block->alloc_buffers[i]; @@ -181,7 +182,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); std::string ReprPrintPrimFunc(const ObjectRef& obj, const PrinterConfig& cfg) { - Doc doc = IRDocsifier(cfg)->AsDoc(obj, ObjectPath::Root()); + IRDocsifier d(cfg); + Doc doc = HeaderWrapper(d, d->AsDoc(obj, ObjectPath::Root())); return DocToPythonScript(doc, cfg); } diff --git a/src/script/printer/tir/utils.h b/src/script/printer/tir/utils.h index 88094ee816ca..0eead9a57713 100644 --- a/src/script/printer/tir/utils.h +++ b/src/script/printer/tir/utils.h @@ -74,6 +74,7 @@ class TIRFrame : public Frame { /*! \brief Creates the TIR common prefix, which is by default `T` */ inline ExprDoc TIR(const IRDocsifier& d, const String& attr) { + d->ir_usage.insert("tir"); return IdDoc(d->cfg->tir_prefix)->Attr(attr); } diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h index cb20eb363ddd..e90fbc0fb39d 100644 --- a/src/script/printer/utils.h +++ b/src/script/printer/utils.h @@ -69,6 +69,26 @@ inline std::string DType2Str(const runtime::DataType& dtype) { return dtype.is_void() ? "void" : runtime::DLDataType2String(dtype); } +/*! \brief Add headers as comments to doc if needed */ +inline Doc HeaderWrapper(const IRDocsifier& d, const Doc& doc) { + if (d->ir_usage.size()) { + Array stmts; + if (d->ir_usage.count("ir")) { + stmts.push_back(CommentDoc("from tvm.script import ir as " + d->cfg->ir_prefix)); + } + if (d->ir_usage.count("tir")) { + stmts.push_back(CommentDoc("from tvm.script import tir as " + d->cfg->tir_prefix)); + } + if (d->ir_usage.count("relax")) { + stmts.push_back(CommentDoc("from tvm.script import relax as " + d->cfg->relax_prefix)); + } + stmts.push_back(CommentDoc("")); + stmts.push_back(Downcast(doc)); + return StmtBlockDoc(stmts); + } + return doc; +} + } // namespace printer } // namespace script } // namespace tvm diff --git a/tests/python/unittest/test_tvmscript_printer_doc.py b/tests/python/unittest/test_tvmscript_printer_doc.py index 16a0c31ac364..6353627c5814 100644 --- a/tests/python/unittest/test_tvmscript_printer_doc.py +++ b/tests/python/unittest/test_tvmscript_printer_doc.py @@ -29,7 +29,9 @@ AttrAccessDoc, CallDoc, ClassDoc, + CommentDoc, DictDoc, + DocStringDoc, ExprStmtDoc, ForDoc, FunctionDoc, @@ -505,6 +507,32 @@ def test_class_doc(decorators, body): assert list(doc.body) == body +@pytest.mark.parametrize( + "comment", + [ + "", + "test comment 1", + "test comment 1\ntest comment 1", + ], +) +def test_comment_doc(comment): + doc = CommentDoc(comment) + assert doc.comment == comment + + +@pytest.mark.parametrize( + "comment", + [ + "", + "test comment 1", + "test comment 1\ntest comment 1", + ], +) +def test_doc_string_doc(comment): + doc = DocStringDoc(comment) + assert doc.comment == comment + + def test_stmt_doc_comment(): doc = ExprStmtDoc(IdDoc("x")) assert doc.comment is None diff --git a/tests/python/unittest/test_tvmscript_printer_ir.py b/tests/python/unittest/test_tvmscript_printer_ir.py index c3da3d8c702b..6b3ac19a5ef8 100644 --- a/tests/python/unittest/test_tvmscript_printer_ir.py +++ b/tests/python/unittest/test_tvmscript_printer_ir.py @@ -37,6 +37,9 @@ def test_ir_module(): _assert_print( mod, """ +# from tvm.script import ir as I +# from tvm.script import tir as T + @I.ir_module class Module: @T.prim_func diff --git a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py index d87f9ec69e05..75beb59d02cf 100644 --- a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py +++ b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py @@ -23,7 +23,9 @@ AssignDoc, CallDoc, ClassDoc, + CommentDoc, DictDoc, + DocStringDoc, ExprStmtDoc, ForDoc, FunctionDoc, @@ -53,7 +55,7 @@ def format_script(s: str) -> str: non_empty_lines = [line for line in s.splitlines() if line and not line.isspace()] if not non_empty_lines: # no actual content - return "\n" + return "" line_indents = [len(line) - len(line.lstrip(" ")) for line in non_empty_lines] spaces_to_remove = min(line_indents) @@ -887,6 +889,58 @@ def test_print_class_doc(decorators, body, expected): assert to_python_script(doc) == format_script(expected) +@pytest.mark.parametrize( + "comment, expected", + [ + ( + "", + "", + ), + ( + "test comment 1", + "# test comment 1", + ), + ( + "test comment 1\ntest comment 2", + """ + # test comment 1 + # test comment 2 + """, + ), + ], + ids=itertools.count(), +) +def test_print_comment_doc(comment, expected): + doc = CommentDoc(comment) + assert to_python_script(doc) == format_script(expected) + + +@pytest.mark.parametrize( + "comment, expected", + [ + ( + "", + "", + ), + ( + "test comment 1", + '"""test comment 1"""', + ), + ( + "test comment 1\ntest comment 2", + ''' + """test comment 1 + test comment 2""" + ''', + ), + ], + ids=itertools.count(), +) +def test_print_doc_string_doc(comment, expected): + doc = DocStringDoc(comment) + assert to_python_script(doc) == format_script(expected) + + @pytest.mark.parametrize( "doc, comment, expected", [ diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index ec69c54396c3..49a33cd0f0e8 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -41,6 +41,8 @@ def test_prim_func(): _assert_print( func, expected=""" +# from tvm.script import tir as T + @T.prim_func def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): T.evaluate(0)""", @@ -62,6 +64,8 @@ def test_prim_func_no_sugar_inlined_buffer(): _assert_print( func, expected=""" +# from tvm.script import tir as T + @T.prim_func def main(a: T.handle, B: T.Buffer((256, 256), "float32")): A = T.match_buffer(a, (128, 128)) @@ -86,6 +90,8 @@ def test_prim_func_no_sugar_shared_buffer_data(): _assert_print( func, expected=""" +# from tvm.script import tir as T + @T.prim_func def main(a: T.handle, b: T.handle): A = T.match_buffer(a, (128, 128)) @@ -698,8 +704,12 @@ def block_with_remap_explicitly(): v3 = T.axis.spatial(128, i3 - 1) v4, v5 = T.axis.remap("RS", [i4, i5]) - expected_output = """@T.prim_func + expected_output = """ +# from tvm.script import tir as T + +@T.prim_func def main(): + # with T.block("root"): for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128): with T.block("update"): v0 = T.axis.spatial(128, i0 + 1) @@ -731,8 +741,12 @@ def root_block_explicitly(): with T.block(): T.evaluate(0) - expected_output = """@T.prim_func + expected_output = """ +# from tvm.script import tir as T + +@T.prim_func def main(): + # with T.block("root"): a = T.alloc_buffer((128, 128)) for i, j in T.grid(128, 128): with T.block(""):