From 4daa5d68d5453e9e5f9340989e788933031e11b7 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Sat, 13 Aug 2022 16:04:39 -0400 Subject: [PATCH] [TVMScript] Printer VarTable (#12336) This PR: - Adds VarTable for the new TVMScript Printer Compared to the prototype version, this: - Removes unnecessary public methods. - GetObjectName - GetUniqueName - Add Frame parameter for `Define` methods. VarTable will add callback to Frame to remove variable when Frame exits. - Changes DocFactory from `ExprDoc(ObjectPath)` to `ExprDoc()` to simplify var definition. Tracking issue: https://github.com/apache/tvm/issues/11912 --- include/tvm/script/printer/var_table.h | 144 ++++++++++++++++ python/tvm/script/printer/var_table.py | 118 +++++++++++++ src/script/printer/var_table.cc | 108 ++++++++++++ tests/cpp/tvmscript_printer_var_table_test.cc | 158 ++++++++++++++++++ .../test_tvmscript_printer_var_table.py | 89 ++++++++++ 5 files changed, 617 insertions(+) create mode 100644 include/tvm/script/printer/var_table.h create mode 100644 python/tvm/script/printer/var_table.py create mode 100644 src/script/printer/var_table.cc create mode 100644 tests/cpp/tvmscript_printer_var_table_test.cc create mode 100644 tests/python/unittest/test_tvmscript_printer_var_table.py diff --git a/include/tvm/script/printer/var_table.h b/include/tvm/script/printer/var_table.h new file mode 100644 index 000000000000..9300a976c569 --- /dev/null +++ b/include/tvm/script/printer/var_table.h @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_PRINTER_VAR_TABLE_H_ +#define TVM_SCRIPT_PRINTER_VAR_TABLE_H_ + +#include +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace script { +namespace printer { + +/*! + * \brief Variable Table manages mapping from variable object to ExprDoc during + * the process of printing TVMScript. + * + * The value type of this map is ExprDoc rather than IdDoc or String. It's + * because variables can be implicitly defined. For example in TIR buffer (tir::Buffer), + * `buf->data` is a variable, while its representation in TVMScript should be an + * expression `x.data`, where `x` is the variable for the buffer itself. + */ +class VarTableNode : public Object { + public: + void VisitAttrs(AttrVisitor*) {} + + /*! + * \brief Define variable by name. + * \param obj The variable object. + * \param name_hint The hint for variable name. + * \param object_path The object_path for the returned ExprDoc. + * \param frame The frame that this variable is defined in. + * + * \return The id doc for this variable. + * + * This function will rename the variable to avoid name conflict with other variables + * in the table. + */ + IdDoc Define(const ObjectRef& obj, const String& name_hint, const ObjectPath& object_path, + const Frame& frame); + + /*! + * \brief Define variable by name. + * \param obj The variable object. + * \param name_hint The hint for variable name. + * \param frame The frame that this variable is defined in. + * + * \return The id doc for this variable. + * + * This is a shortcut version of `Define` which accepts a traced string. + */ + IdDoc Define(const ObjectRef& obj, const TracedObject& name_hint, const Frame& frame) { + return Define(obj, name_hint.Get(), name_hint.GetPath(), frame); + } + + using DocFactory = std::function; + + /*! + * \brief Define variable by doc factory. + * \param obj The variable object. + * \param doc_factory The function to return an ExprDoc object for this variable. + * \param frame The frame that this variable is defined in. + * + * This function is a special form of `Define`. Variable is mapped to ExprDoc rather + * than IdDoc. It's useful when a variable is implicitly defined without a name, like + * the buf->data in TIR, which should be mapped to `AttrDoc(IdDoc(""), "data")`. + * + * This function takes a DocFactory instead of Doc. It's because GetVarDoc needs to + * return a new Doc object every time it's called, as the returned doc will have + * different `soruce_path`. Currently there isn't a good way to deep copy a TVMObject + * so VarTable needs to call a factory function to get a freshly-constructed Doc object + * every time GetVarDoc is called. + */ + void DefineByDoc(const ObjectRef& obj, DocFactory doc_factory, const Frame& frame); + + /*! + * \brief Get the doc for variable. + * \param obj The variable object. + * \param object_path The object path for the variable. + * + * \return The doc for variable, if it exists in the table. Otherwise it returns NullOpt. + */ + Optional GetVarDoc(const ObjectRef& obj, const ObjectPath& object_path) const; + + /*! + * \brief Check if a variable exists in the table. + * \param obj The variable object. + * + * \return a boolean for whether variable exists. + */ + bool IsVarDefined(const ObjectRef& obj) const; + + static constexpr const char* _type_key = "script.printer.VarTable"; + TVM_DECLARE_FINAL_OBJECT_INFO(VarTableNode, Object); + + private: + void RemoveVar(const ObjectRef& obj); + + struct VariableInfo { + DocFactory doc_factory; + Optional name; + }; + std::unordered_map obj2info; + std::unordered_set defined_names; +}; + +/*! + * \brief Reference type of VarTableNode. + */ +class VarTable : public ObjectRef { + public: + /*! + * \brief Create an empty VarTable. + */ + VarTable(); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(VarTable, ObjectRef, VarTableNode); +}; + +} // namespace printer +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_PRINTER_VAR_TABLE_H_ diff --git a/python/tvm/script/printer/var_table.py b/python/tvm/script/printer/var_table.py new file mode 100644 index 000000000000..ea1fa41b3210 --- /dev/null +++ b/python/tvm/script/printer/var_table.py @@ -0,0 +1,118 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Functions to print doc into text format""" + +from typing import Callable, Optional + +from tvm._ffi import register_object +from tvm.runtime import Object, ObjectPath + +from . import _ffi_api +from .doc import ExprDoc, IdDoc +from .frame import Frame + + +@register_object("script.printer.VarTable") +class VarTable(Object): + """ + Variable Table manages mapping from variable object to ExprDoc during + the process of printing TVMScript. + """ + + def __init__(self): + """ + Create an empty VarTable. + """ + self.__init_handle_by_constructor__(_ffi_api.VarTable) # type: ignore # pylint: disable=no-member + + def define(self, obj: Object, name_hint: str, object_path: ObjectPath, frame: Frame) -> IdDoc: + """ + Define a variable by name. + + Parameters + ---------- + obj : Object + The variable object. + name_hint : str + The hint for variable name. + object_path : ObjectPath + The object path to be associated with the returned ExprDoc. + frame : Frame + Then frame that this variable is defined in. + + Returns + ------- + doc : IdDoc + The doc for this variable. + """ + return _ffi_api.VarTableDefine(self, obj, name_hint, object_path, frame) # type: ignore # pylint: disable=no-member + + def define_by_doc(self, obj: Object, doc_factory: Callable[[], ExprDoc], frame: Frame) -> None: + """ + Define a variable by ExprDoc. + + Parameters + ---------- + obj : Object + The variable object. + doc_factory : Callable[[], ExprDoc] + The hint for variable name. + frame : Frame + Then frame that this variable is defined in. + + Returns + ------- + None + """ + _ffi_api.VarTableDefineByDoc(self, obj, doc_factory, frame) # type: ignore # pylint: disable=no-member + + def get_var_doc(self, obj: Object, object_path: ObjectPath) -> Optional[ExprDoc]: + """ + Get the doc for a variable. + + Parameters + ---------- + obj : Object + The variable object. + object_path : ObjectPath + The object path to be associated with the returned ExprDoc. + + Returns + ------- + doc : ExprDoc + The doc for this variable. + """ + return _ffi_api.VarTableGetVarDoc(self, obj, object_path) # type: ignore # pylint: disable=no-member + + def is_var_defined(self, obj: Object) -> bool: + """ + Check whether a variable is defined. + + Parameters + ---------- + obj : Object + The variable object. + + Returns + ------- + is_defined : bool + Whether the variable is defined. + """ + return _ffi_api.VarTableIsVarDefined(self, obj) # type: ignore # pylint: disable=no-member + + def __contains__(self, obj: Object) -> bool: + return self.is_var_defined(obj) diff --git a/src/script/printer/var_table.cc b/src/script/printer/var_table.cc new file mode 100644 index 000000000000..49ba93f9bcfe --- /dev/null +++ b/src/script/printer/var_table.cc @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace script { +namespace printer { + +String GenerateUniqueName(const String& name_hint, std::unordered_set* defined_names) { + String name = name_hint; + for (int i = 1; !defined_names->insert(name).second; ++i) { + name = name_hint + "_" + std::to_string(i); + } + return name; +} + +IdDoc VarTableNode::Define(const ObjectRef& obj, const String& name_hint, + const ObjectPath& object_path, const Frame& frame) { + String name = GenerateUniqueName(name_hint, &this->defined_names); + DocFactory doc_factory = [name]() { return IdDoc(name); }; + + auto result = obj2info.insert({obj, VariableInfo{std::move(doc_factory), name}}); + ICHECK(result.second) << "Duplicated object: " << obj; + + IdDoc def_doc(name); + def_doc->source_paths.push_back(object_path); + + frame->AddExitCallback([this, obj]() { this->RemoveVar(obj); }); + + return def_doc; +} + +void VarTableNode::DefineByDoc(const ObjectRef& obj, DocFactory doc_factory, const Frame& frame) { + ICHECK(obj2info.find(obj) == obj2info.end()) << "Duplicated object: " << obj; + + ICHECK(!doc_factory()->IsInstance()) + << "VarTableNode::Define cannot be used for variable that's mapped to IdDoc."; + + obj2info.insert({obj, VariableInfo{std::move(doc_factory), NullOpt}}); + + frame->AddExitCallback([this, obj]() { this->RemoveVar(obj); }); +} + +Optional VarTableNode::GetVarDoc(const ObjectRef& obj, + const ObjectPath& object_path) const { + auto it = obj2info.find(obj); + if (it == obj2info.end()) { + return NullOpt; + } + ExprDoc doc = it->second.doc_factory(); + doc->source_paths.push_back(object_path); + return doc; +} + +bool VarTableNode::IsVarDefined(const ObjectRef& obj) const { return obj2info.count(obj); } + +void VarTableNode::RemoveVar(const ObjectRef& obj) { + auto it = obj2info.find(obj); + ICHECK(it != obj2info.end()) << "No such object: " << obj; + + if (it->second.name.defined()) { + defined_names.erase(it->second.name.value()); + } + obj2info.erase(it); +} + +VarTable::VarTable() { data_ = make_object(); } + +TVM_REGISTER_NODE_TYPE(VarTableNode); +TVM_REGISTER_GLOBAL("script.printer.VarTable").set_body_typed([]() { return VarTable(); }); +TVM_REGISTER_GLOBAL("script.printer.VarTableDefine") + .set_body_method(&VarTableNode::Define); +TVM_REGISTER_GLOBAL("script.printer.VarTableDefineByDoc") + .set_body_typed([](VarTable var_table, const ObjectRef& obj, runtime::PackedFunc factory, + Frame frame) { + var_table->DefineByDoc( + obj, [f = std::move(factory)]() { return f(); }, frame); + }); +TVM_REGISTER_GLOBAL("script.printer.VarTableGetVarDoc") + .set_body_method(&VarTableNode::GetVarDoc); +TVM_REGISTER_GLOBAL("script.printer.VarTableIsVarDefined") + .set_body_method(&VarTableNode::IsVarDefined); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/tests/cpp/tvmscript_printer_var_table_test.cc b/tests/cpp/tvmscript_printer_var_table_test.cc new file mode 100644 index 000000000000..b447c81ac0b8 --- /dev/null +++ b/tests/cpp/tvmscript_printer_var_table_test.cc @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include + +using namespace tvm; +using namespace tvm::script::printer; + +TEST(PrinterVarTableTest, Define) { + VarTable vars; + MetadataFrame frame; + tir::Var x("x"); + ObjectPath object_path = ObjectPath::Root(); + + IdDoc doc = vars->Define(x, "x", object_path, frame); + + ICHECK_EQ(doc->name, "x"); + + IdDoc second_doc = Downcast(vars->GetVarDoc(x, object_path).value()); + + ICHECK_EQ(second_doc->name, "x"); +} + +TEST(PrinterVarTableTest, DefineByDoc) { + VarTable vars; + MetadataFrame frame; + tir::Var x("x"); + ObjectPath object_path = ObjectPath::Root(); + + auto doc_factory = []() { return LiteralDoc::Str("x"); }; + + vars->DefineByDoc(x, doc_factory, frame); + + ExprDoc doc = vars->GetVarDoc(x, object_path).value(); + + ICHECK_EQ(Downcast(Downcast(doc)->value), "x"); +} + +TEST(PrinterVarTableTest, GetVarDocWithUnknownVariable) { + VarTable vars; + MetadataFrame frame; + tir::Var x("x"); + tir::Var y("y"); + ObjectPath object_path = ObjectPath::Root(); + + Doc doc = vars->Define(x, "x", object_path, frame); + ICHECK(!vars->GetVarDoc(y, object_path).defined()); +} + +TEST(PrinterVarTableTest, GetVarDocWithObjectPath) { + VarTable vars; + MetadataFrame frame; + tir::Var x("x"); + ObjectPath object_path = ObjectPath::Root(); + ObjectPath second_object_path = ObjectPath::Root()->Attr("x"); + + IdDoc doc = vars->Define(x, "x", object_path, frame); + ICHECK_EQ(doc->source_paths[0], object_path); + ICHECK_EQ(doc->source_paths.size(), 1); + + Doc second_doc = vars->GetVarDoc(x, second_object_path).value(); + ICHECK_EQ(second_doc->source_paths[0], second_object_path); + ICHECK_EQ(second_doc->source_paths.size(), 1); +} + +TEST(PrinterVarTableTest, IsVarDefined) { + VarTable vars; + MetadataFrame frame; + tir::Var x("x"); + tir::Var y("y"); + ObjectPath object_path = ObjectPath::Root(); + + vars->Define(x, "x", object_path, frame); + ICHECK(vars->IsVarDefined(x)); + ICHECK(!vars->IsVarDefined(y)); +} + +TEST(PrinterVarTableTest, VarRemovedAfterFrameOutOfScope) { + VarTable vars; + MetadataFrame frame; + tir::Var x("x"); + ObjectPath object_path = ObjectPath::Root(); + + vars->Define(x, "x", object_path, frame); + ICHECK(vars->IsVarDefined(x)); + + frame->ExitWithScope(); + ICHECK(!vars->IsVarDefined(x)); +} + +TEST(PrinterVarTableTest, DefineDuplicateName) { + VarTable vars; + MetadataFrame frame; + tir::Var x("x"); + tir::Var y("y"); + ObjectPath object_path = ObjectPath::Root(); + + IdDoc x_doc = vars->Define(x, "x", object_path, frame); + IdDoc y_doc = vars->Define(y, "x", object_path, frame); + + ICHECK_NE(x_doc->name, y_doc->name); +} + +TEST(PrinterVarTableTest, DefineDuplicateVariable) { + VarTable vars; + MetadataFrame frame; + tir::Var x("x"); + ObjectPath object_path = ObjectPath::Root(); + + vars->Define(x, "x", object_path, frame); + + bool failed = false; + try { + vars->Define(x, "x", object_path, frame); + } catch (...) { + failed = true; + } + ASSERT_EQ(failed, true); +} + +TEST(PrinterVarTableTest, DefineByDocWithIdDoc) { + VarTable vars; + MetadataFrame frame; + tir::Var x("x"); + ObjectPath object_path = ObjectPath::Root(); + + bool failed = false; + try { + // User has to use `Define` if variable needs to be mapped to IdDoc + vars->DefineByDoc( + x, []() { return IdDoc("x"); }, frame); + } catch (...) { + failed = true; + } + ASSERT_EQ(failed, true); +} diff --git a/tests/python/unittest/test_tvmscript_printer_var_table.py b/tests/python/unittest/test_tvmscript_printer_var_table.py new file mode 100644 index 000000000000..eab63a08ddad --- /dev/null +++ b/tests/python/unittest/test_tvmscript_printer_var_table.py @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This file tests the FFI binding of script.printer.VarTable. +These only make sure parameter can be passed to the C++ functions +correctly. The test for the functionality of VarTable is in C++. +""" + +from tvm.runtime import ObjectPath +from tvm.script.printer.doc import LiteralDoc +from tvm.script.printer.frame import VarDefFrame +from tvm.script.printer.var_table import VarTable +from tvm.tir import Var + + +def test_define(): + var_table = VarTable() + var_name = "a" + var_obj = Var(var_name, dtype="int32") + object_path = ObjectPath.root().attr("a") + frame = VarDefFrame() + + id_doc = var_table.define(var_obj, var_name, object_path, frame) + + assert id_doc.name == "a" + assert list(id_doc.source_paths) == [object_path] + + id_doc = var_table.get_var_doc(var_obj, object_path) + + assert id_doc.name == "a" + assert list(id_doc.source_paths) == [object_path] + + +def test_define_by_doc(): + var_table = VarTable() + var_name = "a" + var_obj = Var(var_name, dtype="int32") + object_path = ObjectPath.root().attr("a") + frame = VarDefFrame() + + var_table.define_by_doc(var_obj, lambda: LiteralDoc(var_name), frame) + + var_doc = var_table.get_var_doc(var_obj, object_path) + + assert isinstance(var_doc, LiteralDoc) + assert var_doc.value == var_name + assert list(var_doc.source_paths) == [object_path] + + +def test_is_var_defined(): + var_table = VarTable() + a = Var("a", dtype="int32") + object_path = ObjectPath.root().attr("a") + frame = VarDefFrame() + + var_table.define(a, "a", object_path, frame) + + assert var_table.is_var_defined(a) + assert a in var_table + + +def test_var_out_of_scope(): + var_table = VarTable() + var_name = "a" + var_obj = Var(var_name, dtype="int32") + object_path = ObjectPath.root().attr("a") + frame = VarDefFrame() + + var_table.define(var_obj, var_name, object_path, frame) + + with frame: + assert var_obj in var_table + + assert var_obj not in var_table + assert var_table.get_var_doc(var_obj, object_path) is None