Skip to content

Commit

Permalink
[IR] fix Yield (#214)
Browse files Browse the repository at this point in the history
  • Loading branch information
cloud-mxd authored Mar 29, 2023
1 parent 1cc6296 commit 6e0dcf7
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 2 deletions.
32 changes: 32 additions & 0 deletions include/matxscript/ir/printer/doc.h
Original file line number Diff line number Diff line change
Expand Up @@ -945,6 +945,38 @@ class SliceDoc : public Doc {
MATXSCRIPT_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SliceDoc, Doc, SliceDocNode);
};

/*!
* \brief Doc that represents yield statement.
*
* \sa YieldDoc
*/
class YieldDocNode : public ExprDocNode {
public:
Optional<ExprDoc> value;

void VisitAttrs(AttrVisitor* v) {
ExprDocNode::VisitAttrs(v);
v->Visit("value", &value);
}

static constexpr const char* _type_key = "ir.printer.YieldDoc";
MATXSCRIPT_DECLARE_FINAL_OBJECT_INFO(YieldDocNode, ExprDocNode);
};

/*!
* \brief Reference type of YieldDocNode.
*
* \sa YieldDocNode
*/
class YieldDoc : public ExprDoc {
public:
/*!
* \brief Constructor of YieldDoc.
*/
explicit YieldDoc(Optional<ExprDoc> value);
MATXSCRIPT_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(YieldDoc, ExprDoc, YieldDocNode);
};

/*!
* \brief Doc that represents assign statement.
*
Expand Down
11 changes: 11 additions & 0 deletions src/ir/printer/doc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,12 @@ SliceDoc::SliceDoc(Optional<ExprDoc> start, Optional<ExprDoc> stop, Optional<Exp
this->data_ = std::move(n);
}

YieldDoc::YieldDoc(Optional<ExprDoc> value) {
ObjectPtr<YieldDocNode> n = make_object<YieldDocNode>();
n->value = std::move(value);
this->data_ = std::move(n);
}

AssignDoc::AssignDoc(ExprDoc lhs, Optional<ExprDoc> rhs, Optional<ExprDoc> annotation) {
MXCHECK(rhs.defined() || annotation.defined())
<< "ValueError: At least one of rhs and annotation needs to be non-null for AssignDoc.";
Expand Down Expand Up @@ -463,6 +469,11 @@ MATXSCRIPT_REGISTER_GLOBAL("ir.printer.SliceDoc")
return SliceDoc(std::move(start), std::move(stop), std::move(step));
});

MATXSCRIPT_REGISTER_NODE_TYPE(YieldDocNode);
MATXSCRIPT_REGISTER_GLOBAL("ir.printer.YieldDoc").set_body_typed([](Optional<ExprDoc> value) {
return YieldDoc(std::move(value));
});

MATXSCRIPT_REGISTER_NODE_TYPE(AssignDocNode);
MATXSCRIPT_REGISTER_GLOBAL("ir.printer.AssignDoc")
.set_body_typed([](ExprDoc lhs, Optional<ExprDoc> rhs, Optional<ExprDoc> annotation) {
Expand Down
2 changes: 2 additions & 0 deletions src/ir/printer/doc_printer/base_doc_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,8 @@ void DocPrinter::PrintDoc(const Doc& doc) {
PrintTypedDoc(GetRef<DictCompDoc>(doc_node));
} else if (const auto* doc_node = doc.as<SliceDocNode>()) {
PrintTypedDoc(GetRef<SliceDoc>(doc_node));
} else if (const auto* doc_node = doc.as<YieldDocNode>()) {
PrintTypedDoc(GetRef<YieldDoc>(doc_node));
} else if (const auto* doc_node = doc.as<StmtBlockDocNode>()) {
PrintTypedDoc(GetRef<StmtBlockDoc>(doc_node));
} else if (const auto* doc_node = doc.as<AssignDocNode>()) {
Expand Down
5 changes: 5 additions & 0 deletions src/ir/printer/doc_printer/base_doc_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,11 @@ class DocPrinter {
*/
virtual void PrintTypedDoc(const SliceDoc& doc) = 0;

/*!
* \brief Virtual method to print a YieldDoc
*/
virtual void PrintTypedDoc(const YieldDoc& doc) = 0;

/*!
* \brief Virtual method to print a StmtBlockDoc
*/
Expand Down
9 changes: 9 additions & 0 deletions src/ir/printer/doc_printer/python_doc_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ class PythonDocPrinter : public DocPrinter {
void PrintTypedDoc(const SetCompDoc& doc) final;
void PrintTypedDoc(const DictCompDoc& doc) final;
void PrintTypedDoc(const SliceDoc& doc) final;
void PrintTypedDoc(const YieldDoc& doc) final;
void PrintTypedDoc(const StmtBlockDoc& doc) final;
void PrintTypedDoc(const AssignDoc& doc) final;
void PrintTypedDoc(const IfDoc& doc) final;
Expand Down Expand Up @@ -598,6 +599,14 @@ void PythonDocPrinter::PrintTypedDoc(const SliceDoc& doc) {
}
}

void PythonDocPrinter::PrintTypedDoc(const YieldDoc& doc) {
output_ << "yield";
if (doc->value.defined()) {
output_ << " ";
PrintDoc(doc->value.value());
}
}

void PythonDocPrinter::PrintTypedDoc(const StmtBlockDoc& doc) {
for (const StmtDoc& stmt : doc->stmts) {
PrintDoc(stmt);
Expand Down
4 changes: 2 additions & 2 deletions src/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -900,8 +900,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<ir::HLOYield>( //
"",
[](ir::HLOYield stmt, ObjectPath p, IRDocsifier d) -> Doc {
// TODO: support HLOYieldDoc;
return CommentDoc("TODO");
ExprDoc value = d->AsDoc<ExprDoc>(stmt->symbol, p->Attr("symbol"));
return ExprStmtDoc(YieldDoc(value));
});

} // namespace ir
Expand Down

0 comments on commit 6e0dcf7

Please sign in to comment.