Skip to content

Commit

Permalink
[TVMScript] Python Expression Precedence
Browse files Browse the repository at this point in the history
This PR:

- Handle expression (operator) precedence during Python code printing (`(* 1 (+ 2 3))` prints as
`1 * (2 + 3)`)
- Addresses remaining feedback from previous PR #12112
- Reformats Python import with isort

Tracking issue: #11912
  • Loading branch information
yelite authored and junrushao committed Aug 1, 2022
1 parent c07d77f commit a3c7d73
Show file tree
Hide file tree
Showing 6 changed files with 511 additions and 42 deletions.
4 changes: 2 additions & 2 deletions include/tvm/script/printer/doc.h
Original file line number Diff line number Diff line change
Expand Up @@ -1067,7 +1067,7 @@ class FunctionDocNode : public StmtDocNode {
/*! \brief Decorators of function. */
Array<ExprDoc> decorators;
/*! \brief The return type of function. */
ExprDoc return_type{nullptr};
Optional<ExprDoc> return_type{NullOpt};
/*! \brief The body of function. */
Array<StmtDoc> body;

Expand Down Expand Up @@ -1100,7 +1100,7 @@ class FunctionDoc : public StmtDoc {
* \param body The body of function.
*/
explicit FunctionDoc(IdDoc name, Array<AssignDoc> args, Array<ExprDoc> decorators,
ExprDoc return_type, Array<StmtDoc> body);
Optional<ExprDoc> return_type, Array<StmtDoc> body);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FunctionDoc, StmtDoc, FunctionDocNode);
};

Expand Down
4 changes: 2 additions & 2 deletions python/tvm/script/printer/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,15 +439,15 @@ class FunctionDoc(StmtDoc):
name: IdDoc
args: Sequence[AssignDoc]
decorators: Sequence[ExprDoc]
return_type: ExprDoc
return_type: Optional[ExprDoc]
body: Sequence[StmtDoc]

def __init__(
self,
name: IdDoc,
args: List[AssignDoc],
decorators: List[ExprDoc],
return_type: ExprDoc,
return_type: Optional[ExprDoc],
body: List[StmtDoc],
):
self.__init_handle_by_constructor__(
Expand Down
4 changes: 2 additions & 2 deletions src/script/printer/doc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ ReturnDoc::ReturnDoc(ExprDoc value) {
}

FunctionDoc::FunctionDoc(IdDoc name, Array<AssignDoc> args, Array<ExprDoc> decorators,
ExprDoc return_type, Array<StmtDoc> body) {
Optional<ExprDoc> return_type, Array<StmtDoc> body) {
ObjectPtr<FunctionDocNode> n = make_object<FunctionDocNode>();
n->name = name;
n->args = args;
Expand Down Expand Up @@ -345,7 +345,7 @@ TVM_REGISTER_GLOBAL("script.printer.ReturnDoc").set_body_typed([](ExprDoc value)
TVM_REGISTER_NODE_TYPE(FunctionDocNode);
TVM_REGISTER_GLOBAL("script.printer.FunctionDoc")
.set_body_typed([](IdDoc name, Array<AssignDoc> args, Array<ExprDoc> decorators,
ExprDoc return_type, Array<StmtDoc> body) {
Optional<ExprDoc> return_type, Array<StmtDoc> body) {
return FunctionDoc(name, args, decorators, return_type, body);
});

Expand Down
179 changes: 167 additions & 12 deletions src/script/printer/python_doc_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,114 @@ namespace tvm {
namespace script {
namespace printer {

/*!
* \brief Operator precedence
*
* This is based on
* https://docs.python.org/3/reference/expressions.html#operator-precedence
*/
enum class ExprPrecedence : int32_t {
/*! \brief Unknown precedence */
kUnkown = 0,
/*! \brief Lambda Expression */
kLambda = 1,
/*! \brief Conditional Expression */
kIfThenElse = 2,
/*! \brief Boolean OR */
kBooleanOr = 3,
/*! \brief Boolean AND */
kBooleanAnd = 4,
/*! \brief Boolean NOT */
kBooleanNot = 5,
/*! \brief Comparisons */
kComparison = 6,
/*! \brief Bitwise OR */
kBitwiseOr = 7,
/*! \brief Bitwise XOR */
kBitwiseXor = 8,
/*! \brief Bitwise AND */
kBitwiseAnd = 9,
/*! \brief Shift Operators */
kShift = 10,
/*! \brief Addition and subtraction */
kAdd = 11,
/*! \brief Multiplication, division, floor division, remainder */
kMult = 12,
/*! \brief Positive negative and bitwise NOT */
kUnary = 13,
/*! \brief Exponentiation */
kExp = 14,
/*! \brief Index access, attribute access, call and atom expression */
kIdentity = 15,
};

#define DOC_PRECEDENCE_ENTRY(RefType, Precedence) \
{ RefType::ContainerType::RuntimeTypeIndex(), ExprPrecedence::Precedence }

ExprPrecedence GetExprPrecedence(const ExprDoc& doc) {
// Key is the value of OperationDocNode::Kind
static const std::vector<ExprPrecedence> op_kind_precedence = []() {
using OpKind = OperationDocNode::Kind;
std::map<OpKind, ExprPrecedence> raw_table = {
{OpKind::kUSub, ExprPrecedence::kUnary}, //
{OpKind::kInvert, ExprPrecedence::kUnary}, //
{OpKind::kAdd, ExprPrecedence::kAdd}, //
{OpKind::kSub, ExprPrecedence::kAdd}, //
{OpKind::kMult, ExprPrecedence::kMult}, //
{OpKind::kDiv, ExprPrecedence::kMult}, //
{OpKind::kFloorDiv, ExprPrecedence::kMult}, //
{OpKind::kMod, ExprPrecedence::kMult}, //
{OpKind::kPow, ExprPrecedence::kExp}, //
{OpKind::kLShift, ExprPrecedence::kShift}, //
{OpKind::kRShift, ExprPrecedence::kShift}, //
{OpKind::kBitAnd, ExprPrecedence::kBitwiseAnd}, //
{OpKind::kBitOr, ExprPrecedence::kBitwiseOr}, //
{OpKind::kBitXor, ExprPrecedence::kBitwiseXor}, //
{OpKind::kLt, ExprPrecedence::kComparison}, //
{OpKind::kLtE, ExprPrecedence::kComparison}, //
{OpKind::kEq, ExprPrecedence::kComparison}, //
{OpKind::kNotEq, ExprPrecedence::kComparison}, //
{OpKind::kGt, ExprPrecedence::kComparison}, //
{OpKind::kGtE, ExprPrecedence::kComparison}, //
{OpKind::kIfThenElse, ExprPrecedence::kIfThenElse}, //
};

std::vector<ExprPrecedence> table;
table.resize(static_cast<int>(OperationDocNode::Kind::kSpecialEnd) + 1);

for (const auto& kv : raw_table) {
table[static_cast<int>(kv.first)] = kv.second;
}

return table;
}();

// Key is the type index of Doc
static const std::unordered_map<uint32_t, ExprPrecedence> doc_type_precedence = {
DOC_PRECEDENCE_ENTRY(LiteralDoc, kIdentity), //
DOC_PRECEDENCE_ENTRY(IdDoc, kIdentity), //
DOC_PRECEDENCE_ENTRY(AttrAccessDoc, kIdentity), //
DOC_PRECEDENCE_ENTRY(IndexDoc, kIdentity), //
DOC_PRECEDENCE_ENTRY(CallDoc, kIdentity), //
DOC_PRECEDENCE_ENTRY(LambdaDoc, kLambda), //
DOC_PRECEDENCE_ENTRY(TupleDoc, kIdentity), //
DOC_PRECEDENCE_ENTRY(ListDoc, kIdentity), //
DOC_PRECEDENCE_ENTRY(DictDoc, kIdentity), //
};

if (const auto* op_doc = doc.as<OperationDocNode>()) {
ExprPrecedence precedence = op_kind_precedence[static_cast<int>(op_doc->kind)];
ICHECK(precedence != ExprPrecedence::kUnkown)
<< "Precedence for operator " << static_cast<int>(op_doc->kind) << " is unknown";
return precedence;
} else if (doc_type_precedence.find(doc->type_index()) != doc_type_precedence.end()) {
return doc_type_precedence.at(doc->type_index());
} else {
ICHECK(false) << "Precedence for doc type " << doc->GetTypeKey() << " is unknown";
throw;
}
}

class PythonDocPrinter : public DocPrinter {
public:
explicit PythonDocPrinter(int indent_spaces = 4) : DocPrinter(indent_spaces) {}
Expand Down Expand Up @@ -98,6 +206,42 @@ class PythonDocPrinter : public DocPrinter {
}
}

/*!
* \brief Print expression and add parenthesis if needed.
*/
void PrintChildExpr(const ExprDoc& doc, ExprPrecedence parent_precedence,
bool parenthesis_for_same_precedence = false) {
ExprPrecedence doc_precedence = GetExprPrecedence(doc);
if (doc_precedence < parent_precedence ||
(parenthesis_for_same_precedence && doc_precedence == parent_precedence)) {
output_ << "(";
PrintDoc(doc);
output_ << ")";
} else {
PrintDoc(doc);
}
}

/*!
* \brief Print expression and add parenthesis if doc has lower precedence than parent.
*/
void PrintChildExpr(const ExprDoc& doc, const ExprDoc& parent,
bool parenthesis_for_same_precedence = false) {
ExprPrecedence parent_precedence = GetExprPrecedence(parent);
return PrintChildExpr(doc, parent_precedence, parenthesis_for_same_precedence);
}

/*!
* \brief Print expression and add parenthesis if doc doesn't have higher precedence than parent.
*
* This function should be used to print an child expression that needs to be wrapped
* by parenthesis even if it has the same precedence as its parent, e.g., the `b` in `a + b`
* and the `b` and `c` in `a if b else c`.
*/
void PrintChildExprConservatively(const ExprDoc& doc, const ExprDoc& parent) {
PrintChildExpr(doc, parent, /*parenthesis_for_same_precedence*/ true);
}

void MaybePrintCommentInline(const StmtDoc& stmt) {
if (stmt->comment.defined()) {
const std::string& comment = stmt->comment.value();
Expand Down Expand Up @@ -161,12 +305,12 @@ void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) {
void PythonDocPrinter::PrintTypedDoc(const IdDoc& doc) { output_ << doc->name; }

void PythonDocPrinter::PrintTypedDoc(const AttrAccessDoc& doc) {
PrintDoc(doc->value);
PrintChildExpr(doc->value, doc);
output_ << "." << doc->name;
}

void PythonDocPrinter::PrintTypedDoc(const IndexDoc& doc) {
PrintDoc(doc->value);
PrintChildExpr(doc->value, doc);
if (doc->indices.size() == 0) {
output_ << "[()]";
} else {
Expand Down Expand Up @@ -226,29 +370,38 @@ void PythonDocPrinter::PrintTypedDoc(const OperationDoc& doc) {
// Unary Operators
ICHECK_EQ(doc->operands.size(), 1);
output_ << OperatorToString(doc->kind);
PrintDoc(doc->operands[0]);
PrintChildExpr(doc->operands[0], doc);
} else if (doc->kind == OpKind::kPow) {
// Power operator is different than other binary operators
// It's right-associative and binds less tightly than unary operator on its right.
// https://docs.python.org/3/reference/expressions.html#the-power-operator
// https://docs.python.org/3/reference/expressions.html#operator-precedence
ICHECK_EQ(doc->operands.size(), 2);
PrintChildExprConservatively(doc->operands[0], doc);
output_ << " ** ";
PrintChildExpr(doc->operands[1], ExprPrecedence::kUnary);
} else if (doc->kind < OpKind::kBinaryEnd) {
// Binary Operator
ICHECK_EQ(doc->operands.size(), 2);
PrintDoc(doc->operands[0]);
PrintChildExpr(doc->operands[0], doc);
output_ << " " << OperatorToString(doc->kind) << " ";
PrintDoc(doc->operands[1]);
PrintChildExprConservatively(doc->operands[1], doc);
} else if (doc->kind == OpKind::kIfThenElse) {
ICHECK_EQ(doc->operands.size(), 3)
<< "ValueError: IfThenElse requires 3 operands, but got " << doc->operands.size();
PrintDoc(doc->operands[1]);
PrintChildExpr(doc->operands[1], doc);
output_ << " if ";
PrintDoc(doc->operands[0]);
PrintChildExprConservatively(doc->operands[0], doc);
output_ << " else ";
PrintDoc(doc->operands[2]);
PrintChildExprConservatively(doc->operands[2], doc);
} else {
LOG(FATAL) << "Unknown OperationDocNode::Kind " << static_cast<int>(doc->kind);
throw;
}
}

void PythonDocPrinter::PrintTypedDoc(const CallDoc& doc) {
PrintDoc(doc->callee);
PrintChildExpr(doc->callee, doc);

output_ << "(";

Expand Down Expand Up @@ -285,7 +438,7 @@ void PythonDocPrinter::PrintTypedDoc(const LambdaDoc& doc) {
output_ << "lambda ";
PrintJoinedDocs(doc->args, ", ");
output_ << ": ";
PrintDoc(doc->body);
PrintChildExpr(doc->body, doc);
}

void PythonDocPrinter::PrintTypedDoc(const ListDoc& doc) {
Expand Down Expand Up @@ -444,8 +597,10 @@ void PythonDocPrinter::PrintTypedDoc(const FunctionDoc& doc) {
PrintJoinedDocs(doc->args, ", ");
output_ << ")";

output_ << " -> ";
PrintDoc(doc->return_type);
if (doc->return_type.defined()) {
output_ << " -> ";
PrintDoc(doc->return_type.value());
}

output_ << ":";

Expand Down
47 changes: 29 additions & 18 deletions tests/python/unittest/test_tvmscript_printer_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,30 +21,31 @@

import pytest

import tvm
from tvm.script.printer.doc import (
LiteralDoc,
IdDoc,
AssertDoc,
AssignDoc,
AttrAccessDoc,
IndexDoc,
CallDoc,
OperationKind,
OperationDoc,
ClassDoc,
DictDoc,
ExprStmtDoc,
ForDoc,
FunctionDoc,
IdDoc,
IfDoc,
IndexDoc,
LambdaDoc,
TupleDoc,
ListDoc,
DictDoc,
LiteralDoc,
OperationDoc,
OperationKind,
ReturnDoc,
ScopeDoc,
SliceDoc,
StmtBlockDoc,
AssignDoc,
IfDoc,
TupleDoc,
WhileDoc,
ForDoc,
ScopeDoc,
ExprStmtDoc,
AssertDoc,
ReturnDoc,
FunctionDoc,
ClassDoc,
)


Expand Down Expand Up @@ -450,6 +451,13 @@ def test_return_doc():
[IdDoc("test"), IdDoc("test2")],
],
)
@pytest.mark.parametrize(
"return_type",
[
None,
LiteralDoc(None),
],
)
@pytest.mark.parametrize(
"body",
[
Expand All @@ -458,9 +466,8 @@ def test_return_doc():
[ExprStmtDoc(IdDoc("x")), ExprStmtDoc(IdDoc("y"))],
],
)
def test_function_doc(args, decorators, body):
def test_function_doc(args, decorators, return_type, body):
name = IdDoc("name")
return_type = LiteralDoc(None)

doc = FunctionDoc(name, args, decorators, return_type, body)

Expand Down Expand Up @@ -504,3 +511,7 @@ def test_stmt_doc_comment():
comment = "test comment"
doc.comment = comment
assert doc.comment == comment


if __name__ == "__main__":
tvm.testing.main()
Loading

0 comments on commit a3c7d73

Please sign in to comment.