Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMScript] Comments and docstrings printing #13839

Merged
merged 2 commits into from
Jan 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions include/tvm/script/printer/doc.h
Original file line number Diff line number Diff line change
Expand Up @@ -1194,6 +1194,50 @@ class ClassDoc : public StmtDoc {
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ClassDoc, StmtDoc, ClassDocNode);
};

/*!
* \brief Doc that represents comment.
*
* \sa CommentDoc
*/
class CommentDocNode : public StmtDocNode {
public:
static constexpr const char* _type_key = "script.printer.CommentDoc";
TVM_DECLARE_FINAL_OBJECT_INFO(CommentDocNode, StmtDocNode);
};

/*!
* \brief Reference type of CommentDocNode.
*
* \sa CommentDocNode
*/
class CommentDoc : public StmtDoc {
public:
explicit CommentDoc(String comment);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(CommentDoc, StmtDoc, CommentDocNode);
};

/*!
* \brief Doc that represents docstring.
*
* \sa DocStringDoc
*/
class DocStringDocNode : public StmtDocNode {
public:
static constexpr const char* _type_key = "script.printer.DocStringDoc";
TVM_DECLARE_FINAL_OBJECT_INFO(DocStringDocNode, StmtDocNode);
};

/*!
* \brief Reference type of DocStringDocNode.
*
* \sa DocStringDocNode
*/
class DocStringDoc : public StmtDoc {
public:
explicit DocStringDoc(String docs);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(DocStringDoc, StmtDoc, DocStringDocNode);
};

} // namespace printer
} // namespace script
} // namespace tvm
Expand Down
4 changes: 4 additions & 0 deletions include/tvm/script/printer/ir_docsifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <tvm/script/printer/doc.h>
#include <tvm/script/printer/ir_docsifier_functor.h>

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
Expand Down Expand Up @@ -148,6 +149,8 @@ class IRDocsifierNode : public Object {
std::unordered_set<String> defined_names;
/*! \brief Common prefixes of variable usages */
std::unordered_map<const Object*, std::vector<const Object*>> common_prefix;
/*! \brief The IR usages for headers printing */
std::unordered_set<std::string> ir_usage;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("frames", &frames);
Expand All @@ -156,6 +159,7 @@ class IRDocsifierNode : public Object {
// `obj2info` is not visited
// `defined_names` is not visited
// `common_prefix` is not visited
// `ir_usage` is not visited
}

static constexpr const char* _type_key = "script.printer.IRDocsifier";
Expand Down
20 changes: 20 additions & 0 deletions python/tvm/script/printer/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,3 +521,23 @@ def __init__(self, name: IdDoc, decorators: List[ExprDoc], body: List[StmtDoc]):
decorators,
body,
)


@register_object("script.printer.CommentDoc")
class CommentDoc(StmtDoc):
"""Doc that represents comment."""

def __init__(self, comment: str):
self.__init_handle_by_constructor__(
_ffi_api.CommentDoc, comment # type: ignore # pylint: disable=no-member
)


@register_object("script.printer.DocStringDoc")
class DocStringDoc(StmtDoc):
"""Doc that represents docstring."""

def __init__(self, docs: str):
self.__init_handle_by_constructor__(
_ffi_api.DocStringDoc, docs # type: ignore # pylint: disable=no-member
)
22 changes: 22 additions & 0 deletions src/script/printer/doc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,18 @@ ClassDoc::ClassDoc(IdDoc name, Array<ExprDoc> decorators, Array<StmtDoc> body) {
this->data_ = std::move(n);
}

CommentDoc::CommentDoc(String comment) {
ObjectPtr<CommentDocNode> n = make_object<CommentDocNode>();
n->comment = comment;
this->data_ = std::move(n);
}

DocStringDoc::DocStringDoc(String docs) {
ObjectPtr<DocStringDocNode> n = make_object<DocStringDocNode>();
n->comment = docs;
this->data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(DocNode);
TVM_REGISTER_GLOBAL("script.printer.DocSetSourcePaths")
.set_body_typed([](Doc doc, Array<ObjectPath> source_paths) {
Expand Down Expand Up @@ -365,6 +377,16 @@ TVM_REGISTER_GLOBAL("script.printer.ClassDoc")
return ClassDoc(name, decorators, body);
});

TVM_REGISTER_NODE_TYPE(CommentDocNode);
TVM_REGISTER_GLOBAL("script.printer.CommentDoc").set_body_typed([](String comment) {
return CommentDoc(comment);
});

TVM_REGISTER_NODE_TYPE(DocStringDocNode);
TVM_REGISTER_GLOBAL("script.printer.DocStringDoc").set_body_typed([](String docs) {
return DocStringDoc(docs);
});

} // namespace printer
} // namespace script
} // namespace tvm
4 changes: 4 additions & 0 deletions src/script/printer/doc_printer/base_doc_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,10 @@ void DocPrinter::PrintDoc(const Doc& doc) {
PrintTypedDoc(GetRef<FunctionDoc>(doc_node));
} else if (const auto* doc_node = doc.as<ClassDocNode>()) {
PrintTypedDoc(GetRef<ClassDoc>(doc_node));
} else if (const auto* doc_node = doc.as<CommentDocNode>()) {
PrintTypedDoc(GetRef<CommentDoc>(doc_node));
} else if (const auto* doc_node = doc.as<DocStringDocNode>()) {
PrintTypedDoc(GetRef<DocStringDoc>(doc_node));
} else {
LOG(FATAL) << "Do not know how to print " << doc->GetTypeKey();
throw;
Expand Down
10 changes: 10 additions & 0 deletions src/script/printer/doc_printer/base_doc_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,16 @@ class DocPrinter {
*/
virtual void PrintTypedDoc(const ClassDoc& doc) = 0;

/*!
* \brief Virtual method to print a CommentDoc
*/
virtual void PrintTypedDoc(const CommentDoc& doc) = 0;

/*!
* \brief Virtual method to print a DocStringDoc
*/
virtual void PrintTypedDoc(const DocStringDoc& doc) = 0;

/*!
* \brief Increase the indent level of any content to be
* printed after this call
Expand Down
34 changes: 28 additions & 6 deletions src/script/printer/doc_printer/python_doc_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ class PythonDocPrinter : public DocPrinter {
void PrintTypedDoc(const ScopeDoc& doc) final;
void PrintTypedDoc(const FunctionDoc& doc) final;
void PrintTypedDoc(const ClassDoc& doc) final;
void PrintTypedDoc(const CommentDoc& doc) final;
void PrintTypedDoc(const DocStringDoc& doc) final;

private:
void NewLineWithoutIndent() { output_ << "\n"; }
Expand Down Expand Up @@ -253,11 +255,19 @@ class PythonDocPrinter : public DocPrinter {
}
}

void MaybePrintCommentWithNewLine(const StmtDoc& stmt) {
void MaybePrintCommenMultiLines(const StmtDoc& stmt, bool new_line = false) {
if (stmt->comment.defined()) {
std::vector<std::string> comment_lines = support::Split(stmt->comment.value(), '\n');
bool first_line = true;
for (const std::string& line : comment_lines) {
output_ << "# " << line;
if (first_line) {
output_ << "# " << line;
first_line = false;
} else {
NewLine() << "# " << line;
}
}
if (new_line) {
NewLine();
}
}
Expand Down Expand Up @@ -523,7 +533,7 @@ void PythonDocPrinter::PrintTypedDoc(const AssignDoc& doc) {
}

void PythonDocPrinter::PrintTypedDoc(const IfDoc& doc) {
MaybePrintCommentWithNewLine(doc);
MaybePrintCommenMultiLines(doc, true);
output_ << "if ";
PrintDoc(doc->predicate);
output_ << ":";
Expand All @@ -538,7 +548,7 @@ void PythonDocPrinter::PrintTypedDoc(const IfDoc& doc) {
}

void PythonDocPrinter::PrintTypedDoc(const WhileDoc& doc) {
MaybePrintCommentWithNewLine(doc);
MaybePrintCommenMultiLines(doc, true);
output_ << "while ";
PrintDoc(doc->predicate);
output_ << ":";
Expand All @@ -547,7 +557,7 @@ void PythonDocPrinter::PrintTypedDoc(const WhileDoc& doc) {
}

void PythonDocPrinter::PrintTypedDoc(const ForDoc& doc) {
MaybePrintCommentWithNewLine(doc);
MaybePrintCommenMultiLines(doc, true);
output_ << "for ";
if (const auto* tuple = doc->lhs.as<TupleDocNode>()) {
if (tuple->elements.size() == 1) {
Expand All @@ -567,7 +577,7 @@ void PythonDocPrinter::PrintTypedDoc(const ForDoc& doc) {
}

void PythonDocPrinter::PrintTypedDoc(const ScopeDoc& doc) {
MaybePrintCommentWithNewLine(doc);
MaybePrintCommenMultiLines(doc, true);
output_ << "with ";
PrintDoc(doc->rhs);
if (doc->lhs != nullptr) {
Expand Down Expand Up @@ -642,6 +652,18 @@ void PythonDocPrinter::PrintTypedDoc(const ClassDoc& doc) {
NewLineWithoutIndent();
}

void PythonDocPrinter::PrintTypedDoc(const CommentDoc& doc) {
if (doc->comment.defined()) {
MaybePrintCommenMultiLines(doc, false);
}
}

void PythonDocPrinter::PrintTypedDoc(const DocStringDoc& doc) {
if (doc->comment.defined() && !doc->comment.value().empty()) {
output_ << "\"\"\"" << doc->comment.value() << "\"\"\"";
}
}

String DocToPythonScript(Doc doc, const PrinterConfig& cfg) {
if (cfg->num_context_lines < 0) {
cfg->num_context_lines = std::numeric_limits<int32_t>::max();
Expand Down
3 changes: 2 additions & 1 deletion src/script/printer/ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ std::string ReprPrintIRModule(const ObjectRef& mod, const PrinterConfig& cfg) {
return s.value();
}
}
Doc doc = IRDocsifier(cfg)->AsDoc(mod, ObjectPath::Root());
IRDocsifier d(cfg);
Doc doc = HeaderWrapper(d, d->AsDoc(mod, ObjectPath::Root()));
return DocToPythonScript(doc, cfg);
}

Expand Down
1 change: 1 addition & 0 deletions src/script/printer/ir/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ namespace printer {

/*! \brief Creates the IR common prefix, which is by default `I` */
inline ExprDoc IR(const IRDocsifier& d, const String& attr) {
d->ir_usage.insert("ir");
return IdDoc(d->cfg->ir_prefix)->Attr(attr);
}

Expand Down
4 changes: 3 additions & 1 deletion src/script/printer/tir/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
if (implicit_root_block) {
tir::Block root_block = implicit_root_block.value();
ObjectPath root_block_p = p->Attr("body")->Attr("body");
(*frame)->stmts.push_back(CommentDoc("with T.block(\"root\"):"));
// Handle root block `alloc_buffer`
for (int i = 0, n = root_block->alloc_buffers.size(); i < n; ++i) {
tir::Buffer buffer = root_block->alloc_buffers[i];
Expand Down Expand Up @@ -181,7 +182,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
});

std::string ReprPrintPrimFunc(const ObjectRef& obj, const PrinterConfig& cfg) {
Doc doc = IRDocsifier(cfg)->AsDoc(obj, ObjectPath::Root());
IRDocsifier d(cfg);
Doc doc = HeaderWrapper(d, d->AsDoc(obj, ObjectPath::Root()));
return DocToPythonScript(doc, cfg);
}

Expand Down
1 change: 1 addition & 0 deletions src/script/printer/tir/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class TIRFrame : public Frame {

/*! \brief Creates the TIR common prefix, which is by default `T` */
inline ExprDoc TIR(const IRDocsifier& d, const String& attr) {
d->ir_usage.insert("tir");
return IdDoc(d->cfg->tir_prefix)->Attr(attr);
}

Expand Down
20 changes: 20 additions & 0 deletions src/script/printer/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,26 @@ inline std::string DType2Str(const runtime::DataType& dtype) {
return dtype.is_void() ? "void" : runtime::DLDataType2String(dtype);
}

/*! \brief Add headers as comments to doc if needed */
inline Doc HeaderWrapper(const IRDocsifier& d, const Doc& doc) {
if (d->ir_usage.size()) {
Array<StmtDoc> stmts;
if (d->ir_usage.count("ir")) {
stmts.push_back(CommentDoc("from tvm.script import ir as " + d->cfg->ir_prefix));
}
if (d->ir_usage.count("tir")) {
stmts.push_back(CommentDoc("from tvm.script import tir as " + d->cfg->tir_prefix));
}
if (d->ir_usage.count("relax")) {
stmts.push_back(CommentDoc("from tvm.script import relax as " + d->cfg->relax_prefix));
}
stmts.push_back(CommentDoc(""));
stmts.push_back(Downcast<StmtDoc>(doc));
return StmtBlockDoc(stmts);
}
return doc;
}

} // namespace printer
} // namespace script
} // namespace tvm
Expand Down
28 changes: 28 additions & 0 deletions tests/python/unittest/test_tvmscript_printer_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@
AttrAccessDoc,
CallDoc,
ClassDoc,
CommentDoc,
DictDoc,
DocStringDoc,
ExprStmtDoc,
ForDoc,
FunctionDoc,
Expand Down Expand Up @@ -505,6 +507,32 @@ def test_class_doc(decorators, body):
assert list(doc.body) == body


@pytest.mark.parametrize(
"comment",
[
"",
"test comment 1",
"test comment 1\ntest comment 1",
],
)
def test_comment_doc(comment):
doc = CommentDoc(comment)
assert doc.comment == comment


@pytest.mark.parametrize(
"comment",
[
"",
"test comment 1",
"test comment 1\ntest comment 1",
],
)
def test_doc_string_doc(comment):
doc = DocStringDoc(comment)
assert doc.comment == comment


def test_stmt_doc_comment():
doc = ExprStmtDoc(IdDoc("x"))
assert doc.comment is None
Expand Down
3 changes: 3 additions & 0 deletions tests/python/unittest/test_tvmscript_printer_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def test_ir_module():
_assert_print(
mod,
"""
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
@T.prim_func
Expand Down
Loading