From f9de7aa2586d8f20eec6b77636ff651bbda62e54 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Sat, 17 Sep 2022 16:54:01 -0400 Subject: [PATCH] [TVMScript] Add more helper functions to the printer infra (#12829) This PR is split from https://github.com/apache/tvm/pull/12492, to make the necessary updates to the printer infra for future PRs of TIR printer. Tracking issue: https://github.com/apache/tvm/issues/11912 Co-authored-by: Greg Bonik --- include/tvm/script/printer/doc.h | 64 +++++++++++++ .../script/printer/traced_object_functor.h | 37 +------- include/tvm/script/printer/var_table.h | 11 +++ src/script/printer/doc.cc | 30 ++++-- src/script/printer/ir_docsifier.cc | 2 +- src/script/printer/utils.h | 93 +++++++++++++++++++ src/script/printer/var_table.cc | 3 +- .../cpp/tvmscript_printer_irdocsifier_test.cc | 13 ++- ...ript_printer_traced_object_functor_test.cc | 37 ++++---- 9 files changed, 228 insertions(+), 62 deletions(-) create mode 100644 src/script/printer/utils.h diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index 72f343354b1b..1ee7fd6a7fd4 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -22,6 +22,7 @@ #include #include #include +#include namespace tvm { namespace script { @@ -87,6 +88,15 @@ class ExprDocNode : public DocNode { */ ExprDoc Attr(String attr) const; + /*! + * \brief Create a doc representing attribute access on the current ExprDoc + * \param attr The attribute to access. + * + * The ObjectPath of attr will be pushed to the source_path of the returned + * doc. + */ + ExprDoc Attr(TracedObject attr) const; + /*! * \brief Create a doc representing index access on the current ExprDoc * \param indices The indices to access. @@ -242,6 +252,7 @@ class LiteralDocNode : public ExprDocNode { class LiteralDoc : public ExprDoc { protected: explicit LiteralDoc(ObjectRef value); + LiteralDoc(ObjectRef value, ObjectPath object_path); public: /*! @@ -249,30 +260,83 @@ class LiteralDoc : public ExprDoc { */ static LiteralDoc None() { return LiteralDoc(ObjectRef(nullptr)); } + /*! + * \brief Create a LiteralDoc to represent None/null/empty value. + * \param object_path The source path of the returned Doc. + */ + static LiteralDoc None(ObjectPath object_path) { + return LiteralDoc(ObjectRef(nullptr), object_path); + } + /*! * \brief Create a LiteralDoc to represent integer. * \param v The integer value. */ static LiteralDoc Int(int v) { return LiteralDoc(IntImm(DataType::Int(64), v)); } + /*! + * \brief Create a LiteralDoc to represent integer. + * \param v The integer value. + * + * The ObjectPath of v will be pushed to the source_path of the returned doc. + */ + static LiteralDoc Int(const TracedObject& v) { return LiteralDoc(v.Get(), v.GetPath()); } + + /*! + * \brief Create a LiteralDoc to represent integer. + * \param v The integer value. + * + * The ObjectPath of v will be pushed to the source_path of the returned doc. + */ + static LiteralDoc Int(const TracedBasicValue& v) { + return LiteralDoc(IntImm(DataType::Int(64), v.Get()), v.GetPath()); + } /*! * \brief Create a LiteralDoc to represent boolean. * \param v The boolean value. */ static LiteralDoc Boolean(bool v) { return LiteralDoc(IntImm(DataType::Bool(), v)); } + /*! + * \brief Create a LiteralDoc to represent boolean. + * \param v The boolean value. + * + * The ObjectPath of v will be pushed to the source_path of the returned doc. + */ + static LiteralDoc Boolean(const TracedBasicValue& v) { + return LiteralDoc(IntImm(DataType::Bool(), v.Get()), v.GetPath()); + } + /*! * \brief Create a LiteralDoc to represent float. * \param v The float value. */ static LiteralDoc Float(double v) { return LiteralDoc(FloatImm(DataType::Float(64), v)); } + /*! + * \brief Create a LiteralDoc to represent float. + * \param v The float value. + * + * The ObjectPath of v will be pushed to the source_path of the returned doc. + */ + static LiteralDoc Float(const TracedObject& v) { + return LiteralDoc(v.Get(), v.GetPath()); + } + /*! * \brief Create a LiteralDoc to represent string. * \param v The string value. */ static LiteralDoc Str(const String& v) { return LiteralDoc(v); } + /*! + * \brief Create a LiteralDoc to represent string. + * \param v The string value. + * + * The ObjectPath of v will be pushed to the source_path of the returned doc. + */ + static LiteralDoc Str(const TracedObject& v) { return LiteralDoc(v.Get(), v.GetPath()); } + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LiteralDoc, ExprDoc, LiteralDocNode); }; diff --git a/include/tvm/script/printer/traced_object_functor.h b/include/tvm/script/printer/traced_object_functor.h index 6caaf8a6e0d5..8f72d139a5a5 100644 --- a/include/tvm/script/printer/traced_object_functor.h +++ b/include/tvm/script/printer/traced_object_functor.h @@ -34,35 +34,6 @@ namespace tvm { namespace script { namespace printer { -namespace { - -namespace detail { -/*! - * \brief Helper template class to extract the type of first argument of a function - * \tparam FType The function type. - */ -template -struct FirstArgTypeGetter; - -template -struct FirstArgTypeGetter { - using T = ArgOne; -}; - -/*! - * \brief Template alias for the type of first argument of a function - * \tparam FType The function type. - * - * The name of public functions are in snake case to be consistent with - * tvm/node/functor.h - */ -template -using FirstArgType = typename detail::FirstArgTypeGetter< - typename tvm::runtime::detail::function_signature::FType>::T; -} // namespace detail - -} // namespace - /* * This type alias and the following free functions are created to reduce the binary bloat * from template and also hide implementation details from this header @@ -156,8 +127,7 @@ class TracedObjectFunctor { * * The diaptch function should have signature `R(TracedObject, Args...)`. */ - template ::ObjectRefType, + template ::value>> TSelf& set_dispatch(String token, TCallable f) { return set_dispatch( @@ -177,9 +147,10 @@ class TracedObjectFunctor { * * Default dispatch function has an empty string as dispatch token. */ - template + template ::value>> TSelf& set_dispatch(TCallable&& f) { - return set_dispatch(kDefaultDispatchToken, std::forward(f)); + return set_dispatch(kDefaultDispatchToken, std::forward(f)); } /*! diff --git a/include/tvm/script/printer/var_table.h b/include/tvm/script/printer/var_table.h index 9300a976c569..2cd9335213a3 100644 --- a/include/tvm/script/printer/var_table.h +++ b/include/tvm/script/printer/var_table.h @@ -103,6 +103,17 @@ class VarTableNode : public Object { */ Optional GetVarDoc(const ObjectRef& obj, const ObjectPath& object_path) const; + /*! + * \brief Get the doc for variable. + * \param obj The traced variable object. + * + * \return The doc for variable, if it exists in the table. Otherwise it returns NullOpt. + */ + template + Optional GetVarDoc(const TracedObject obj) const { + return GetVarDoc(obj.Get(), obj.GetPath()); + } + /*! * \brief Check if a variable exists in the table. * \param obj The variable object. diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc index d6f5ff35ab53..f3b431bd62db 100644 --- a/src/script/printer/doc.cc +++ b/src/script/printer/doc.cc @@ -27,6 +27,12 @@ namespace printer { ExprDoc ExprDocNode::Attr(String attr) const { return AttrAccessDoc(GetRef(this), attr); } +ExprDoc ExprDocNode::Attr(TracedObject attr) const { + auto doc = AttrAccessDoc(GetRef(this), attr.Get()); + doc->source_paths.push_back(attr.GetPath()); + return doc; +} + ExprDoc ExprDocNode::operator[](Array indices) const { return IndexDoc(GetRef(this), indices); } @@ -54,6 +60,13 @@ LiteralDoc::LiteralDoc(ObjectRef value) { this->data_ = std::move(n); } +LiteralDoc::LiteralDoc(ObjectRef value, ObjectPath object_path) { + ObjectPtr n = make_object(); + n->value = value; + n->source_paths.push_back(object_path); + this->data_ = std::move(n); +} + IdDoc::IdDoc(String name) { ObjectPtr n = make_object(); n->name = name; @@ -225,7 +238,8 @@ TVM_REGISTER_GLOBAL("script.printer.DocSetSourcePaths") }); TVM_REGISTER_NODE_TYPE(ExprDocNode); -TVM_REGISTER_GLOBAL("script.printer.ExprDocAttr").set_body_method(&ExprDocNode::Attr); +TVM_REGISTER_GLOBAL("script.printer.ExprDocAttr") + .set_body_method(&ExprDocNode::Attr); TVM_REGISTER_GLOBAL("script.printer.ExprDocIndex") .set_body_method(&ExprDocNode::operator[]); TVM_REGISTER_GLOBAL("script.printer.ExprDocCall") @@ -242,11 +256,15 @@ TVM_REGISTER_GLOBAL("script.printer.StmtBlockDoc").set_body_typed([](Array(LiteralDoc::None); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt") + .set_body_typed(LiteralDoc::Int); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocBoolean") + .set_body_typed(LiteralDoc::Boolean); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloat") + .set_body_typed(LiteralDoc::Float); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocStr") + .set_body_typed(LiteralDoc::Str); TVM_REGISTER_NODE_TYPE(IdDocNode); TVM_REGISTER_GLOBAL("script.printer.IdDoc").set_body_typed([](String name) { return IdDoc(name); }); diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc index b72ed48db63b..7f032ec50269 100644 --- a/src/script/printer/ir_docsifier.cc +++ b/src/script/printer/ir_docsifier.cc @@ -61,7 +61,7 @@ RootNodeContainer::RootNodeContainer(ObjectRef root_node) { // }); // \endcode TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch([](TracedObject obj, IRDocsifier p) -> Doc { + .set_dispatch([](TracedObject obj, IRDocsifier p) -> Doc { String top_dispatch_token = p->dispatch_tokens.back(); ICHECK_NE(top_dispatch_token, ""); ICHECK(false) << "Printing IR " << top_dispatch_token << " is not implemented."; diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h new file mode 100644 index 000000000000..abe7ce5e9a88 --- /dev/null +++ b/src/script/printer/utils.h @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_PRINTER_UTILS_H_ +#define TVM_SCRIPT_PRINTER_UTILS_H_ + +#include +#include + +#include + +namespace tvm { +namespace script { +namespace printer { + +template +Array AsDocArray(const TracedArray& refs, const IRDocsifier& ir_docsifier) { + Array result; + for (auto ref : refs) { + result.push_back(ir_docsifier->AsExprDoc(ref)); + } + return result; +} + +template +Array AsDocArray(std::initializer_list&& refs, const IRDocsifier& ir_docsifier) { + Array result; + for (auto& ref : refs) { + result.push_back(ir_docsifier->AsExprDoc(ref)); + } + return result; +} + +template +Array AsExprDocArray(const TracedArray& refs, const IRDocsifier& ir_docsifier) { + return AsDocArray(refs, ir_docsifier); +} + +template +Array AsExprDocArray(std::initializer_list&& refs, + const IRDocsifier& ir_docsifier) { + return AsDocArray(std::move(refs), ir_docsifier); +} + +inline DictDoc AsDictDoc(const TracedMap& dict, + const IRDocsifier& ir_docsifier) { + Array keys; + Array values; + + for (auto p : dict) { + keys.push_back(LiteralDoc::Str(p.first)); + values.push_back(ir_docsifier->AsExprDoc(p.second)); + } + + auto doc = DictDoc(keys, values); + doc->source_paths.push_back(dict.GetPath()); + return doc; +} + +template +inline ListDoc AsListDoc(const TracedArray& arr, const IRDocsifier& ir_docsifier) { + auto ret = ListDoc(AsExprDocArray(arr, ir_docsifier)); + ret->source_paths.push_back(arr.GetPath()); + return ret; +} + +template +inline TupleDoc AsTupleDoc(const TracedArray& arr, const IRDocsifier& ir_docsifier) { + auto ret = TupleDoc(AsExprDocArray(arr, ir_docsifier)); + ret->source_paths.push_back(arr.GetPath()); + return ret; +} + +} // namespace printer +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_PRINTER_UTILS_H_ diff --git a/src/script/printer/var_table.cc b/src/script/printer/var_table.cc index 49ba93f9bcfe..62d8b2f66cc2 100644 --- a/src/script/printer/var_table.cc +++ b/src/script/printer/var_table.cc @@ -99,7 +99,8 @@ TVM_REGISTER_GLOBAL("script.printer.VarTableDefineByDoc") obj, [f = std::move(factory)]() { return f(); }, frame); }); TVM_REGISTER_GLOBAL("script.printer.VarTableGetVarDoc") - .set_body_method(&VarTableNode::GetVarDoc); + .set_body_method, const ObjectRef&, + const ObjectPath&>(&VarTableNode::GetVarDoc); TVM_REGISTER_GLOBAL("script.printer.VarTableIsVarDefined") .set_body_method(&VarTableNode::IsVarDefined); diff --git a/tests/cpp/tvmscript_printer_irdocsifier_test.cc b/tests/cpp/tvmscript_printer_irdocsifier_test.cc index fcdb5ed04e41..8c68399df222 100644 --- a/tests/cpp/tvmscript_printer_irdocsifier_test.cc +++ b/tests/cpp/tvmscript_printer_irdocsifier_test.cc @@ -45,14 +45,19 @@ class TestObject : public ObjectRef { TVM_REGISTER_NODE_TYPE(TestObjectNode); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch([](TracedObject obj, IRDocsifier p) { return IdDoc("x"); }); + .set_dispatch([](TracedObject obj, IRDocsifier p) { + return IdDoc("x"); + }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("tir", [](TracedObject obj, IRDocsifier p) { return IdDoc("tir"); }); + .set_dispatch("tir", [](TracedObject obj, IRDocsifier p) { + return IdDoc("tir"); + }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("relax", - [](TracedObject obj, IRDocsifier p) { return IdDoc("relax"); }); + .set_dispatch("relax", [](TracedObject obj, IRDocsifier p) { + return IdDoc("relax"); + }); TEST(PrinterIRDocsifierTest, AsDoc) { IRDocsifier p(Map{}); diff --git a/tests/cpp/tvmscript_printer_traced_object_functor_test.cc b/tests/cpp/tvmscript_printer_traced_object_functor_test.cc index 374eb609b6cb..d662ce132405 100644 --- a/tests/cpp/tvmscript_printer_traced_object_functor_test.cc +++ b/tests/cpp/tvmscript_printer_traced_object_functor_test.cc @@ -33,7 +33,7 @@ class FooObjectNode : public Object { public: void VisitAttrs(AttrVisitor* v) {} - static constexpr const char* _type_key = "test.FooObject"; + static constexpr const char* _type_key = "test.TracedObjectFunctor.FooObject"; TVM_DECLARE_FINAL_OBJECT_INFO(FooObjectNode, Object); }; @@ -49,7 +49,7 @@ class BarObjectNode : public Object { public: void VisitAttrs(AttrVisitor* v) {} - static constexpr const char* _type_key = "test.BarObject"; + static constexpr const char* _type_key = "test.TracedObjectFunctor.BarObject"; TVM_DECLARE_FINAL_OBJECT_INFO(BarObjectNode, Object); }; @@ -69,8 +69,8 @@ TEST(TracedObjectFunctorTest, NormalRegistration) { TracedObjectFunctor functor; ObjectPath path = ObjectPath::Root(); - functor.set_dispatch([](TracedObject o) -> String { return "Foo"; }); - functor.set_dispatch([](TracedObject o) -> String { return "Bar"; }); + functor.set_dispatch([](TracedObject o) -> String { return "Foo"; }); + functor.set_dispatch([](TracedObject o) -> String { return "Bar"; }); ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "Foo"); ICHECK_EQ(functor("", MakeTraced(BarObject(), path)), "Bar"); @@ -80,8 +80,8 @@ TEST(TracedObjectFunctorTest, RegistrationWithFunction) { TracedObjectFunctor functor; ObjectPath path = ObjectPath::Root(); - functor.set_dispatch([](TracedObject o) -> String { return "FooLambda"; }); - functor.set_dispatch("tir", ComputeFoo); + functor.set_dispatch([](TracedObject o) -> String { return "FooLambda"; }); + functor.set_dispatch("tir", ComputeFoo); ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "FooLambda"); ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo"); @@ -91,9 +91,11 @@ TEST(TracedObjectFunctorTest, RegistrationWithDispatchToken) { TracedObjectFunctor functor; ObjectPath path = ObjectPath::Root(); - functor.set_dispatch([](TracedObject o) -> String { return "Foo"; }); - functor.set_dispatch("tir", [](TracedObject o) -> String { return "Foo tir"; }); - functor.set_dispatch("relax", [](TracedObject o) -> String { return "Foo relax"; }); + functor.set_dispatch([](TracedObject o) -> String { return "Foo"; }); + functor.set_dispatch("tir", + [](TracedObject o) -> String { return "Foo tir"; }); + functor.set_dispatch("relax", + [](TracedObject o) -> String { return "Foo relax"; }); ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "Foo"); ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo tir"); @@ -119,8 +121,8 @@ TEST(TracedObjectFunctorTest, ExtraArg) { TracedObjectFunctor functor; ObjectPath path = ObjectPath::Root(); - functor.set_dispatch([](TracedObject o, int x) { return x; }); - functor.set_dispatch([](TracedObject o, int x) { return x + 1; }); + functor.set_dispatch([](TracedObject o, int x) { return x; }); + functor.set_dispatch([](TracedObject o, int x) { return x + 1; }); ICHECK_EQ(functor("", MakeTraced(FooObject(), path), 2), 2); ICHECK_EQ(functor("", MakeTraced(BarObject(), path), 2), 3); @@ -131,8 +133,9 @@ TEST(TracedObjectFunctorTest, RemoveDispatchFunction) { TracedObjectFunctor functor; ObjectPath path = ObjectPath::Root(); - functor.set_dispatch([](TracedObject o) -> String { return "Foo"; }); - functor.set_dispatch("tir", [](TracedObject o) -> String { return "Foo tir"; }); + functor.set_dispatch([](TracedObject o) -> String { return "Foo"; }); + functor.set_dispatch("tir", + [](TracedObject o) -> String { return "Foo tir"; }); ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "Foo"); ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo tir"); @@ -158,11 +161,11 @@ TEST(TracedObjectFunctorTest, DuplicateRegistration_WithoutToken) { TracedObjectFunctor functor; ObjectPath path = ObjectPath::Root(); - functor.set_dispatch([](TracedObject o, int x) { return x; }); + functor.set_dispatch([](TracedObject o, int x) { return x; }); bool failed = false; try { - functor.set_dispatch([](TracedObject o, int x) { return x; }); + functor.set_dispatch([](TracedObject o, int x) { return x; }); } catch (...) { failed = true; } @@ -173,11 +176,11 @@ TEST(TracedObjectFunctorTest, DuplicateRegistration_WithToken) { TracedObjectFunctor functor; ObjectPath path = ObjectPath::Root(); - functor.set_dispatch("tir", [](TracedObject o, int x) { return x; }); + functor.set_dispatch("tir", [](TracedObject o, int x) { return x; }); bool failed = false; try { - functor.set_dispatch("tir", [](TracedObject o, int x) { return x; }); + functor.set_dispatch("tir", [](TracedObject o, int x) { return x; }); } catch (...) { failed = true; }