-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TVMScript] Printer VarTable (#12336)
This PR: - Adds VarTable for the new TVMScript Printer Compared to the prototype version, this: - Removes unnecessary public methods. - GetObjectName - GetUniqueName - Add Frame parameter for `Define` methods. VarTable will add callback to Frame to remove variable when Frame exits. - Changes DocFactory from `ExprDoc(ObjectPath)` to `ExprDoc()` to simplify var definition. Tracking issue: #11912
- Loading branch information
Showing
5 changed files
with
617 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
#ifndef TVM_SCRIPT_PRINTER_VAR_TABLE_H_ | ||
#define TVM_SCRIPT_PRINTER_VAR_TABLE_H_ | ||
|
||
#include <tvm/node/node.h> | ||
#include <tvm/node/object_path.h> | ||
#include <tvm/script/printer/doc.h> | ||
#include <tvm/script/printer/frame.h> | ||
#include <tvm/script/printer/traced_object.h> | ||
|
||
#include <unordered_map> | ||
#include <unordered_set> | ||
|
||
namespace tvm { | ||
namespace script { | ||
namespace printer { | ||
|
||
/*! | ||
* \brief Variable Table manages mapping from variable object to ExprDoc during | ||
* the process of printing TVMScript. | ||
* | ||
* The value type of this map is ExprDoc rather than IdDoc or String. It's | ||
* because variables can be implicitly defined. For example in TIR buffer (tir::Buffer), | ||
* `buf->data` is a variable, while its representation in TVMScript should be an | ||
* expression `x.data`, where `x` is the variable for the buffer itself. | ||
*/ | ||
class VarTableNode : public Object { | ||
public: | ||
void VisitAttrs(AttrVisitor*) {} | ||
|
||
/*! | ||
* \brief Define variable by name. | ||
* \param obj The variable object. | ||
* \param name_hint The hint for variable name. | ||
* \param object_path The object_path for the returned ExprDoc. | ||
* \param frame The frame that this variable is defined in. | ||
* | ||
* \return The id doc for this variable. | ||
* | ||
* This function will rename the variable to avoid name conflict with other variables | ||
* in the table. | ||
*/ | ||
IdDoc Define(const ObjectRef& obj, const String& name_hint, const ObjectPath& object_path, | ||
const Frame& frame); | ||
|
||
/*! | ||
* \brief Define variable by name. | ||
* \param obj The variable object. | ||
* \param name_hint The hint for variable name. | ||
* \param frame The frame that this variable is defined in. | ||
* | ||
* \return The id doc for this variable. | ||
* | ||
* This is a shortcut version of `Define` which accepts a traced string. | ||
*/ | ||
IdDoc Define(const ObjectRef& obj, const TracedObject<String>& name_hint, const Frame& frame) { | ||
return Define(obj, name_hint.Get(), name_hint.GetPath(), frame); | ||
} | ||
|
||
using DocFactory = std::function<ExprDoc()>; | ||
|
||
/*! | ||
* \brief Define variable by doc factory. | ||
* \param obj The variable object. | ||
* \param doc_factory The function to return an ExprDoc object for this variable. | ||
* \param frame The frame that this variable is defined in. | ||
* | ||
* This function is a special form of `Define`. Variable is mapped to ExprDoc rather | ||
* than IdDoc. It's useful when a variable is implicitly defined without a name, like | ||
* the buf->data in TIR, which should be mapped to `AttrDoc(IdDoc("<buffer_name>"), "data")`. | ||
* | ||
* This function takes a DocFactory instead of Doc. It's because GetVarDoc needs to | ||
* return a new Doc object every time it's called, as the returned doc will have | ||
* different `soruce_path`. Currently there isn't a good way to deep copy a TVMObject | ||
* so VarTable needs to call a factory function to get a freshly-constructed Doc object | ||
* every time GetVarDoc is called. | ||
*/ | ||
void DefineByDoc(const ObjectRef& obj, DocFactory doc_factory, const Frame& frame); | ||
|
||
/*! | ||
* \brief Get the doc for variable. | ||
* \param obj The variable object. | ||
* \param object_path The object path for the variable. | ||
* | ||
* \return The doc for variable, if it exists in the table. Otherwise it returns NullOpt. | ||
*/ | ||
Optional<ExprDoc> GetVarDoc(const ObjectRef& obj, const ObjectPath& object_path) const; | ||
|
||
/*! | ||
* \brief Check if a variable exists in the table. | ||
* \param obj The variable object. | ||
* | ||
* \return a boolean for whether variable exists. | ||
*/ | ||
bool IsVarDefined(const ObjectRef& obj) const; | ||
|
||
static constexpr const char* _type_key = "script.printer.VarTable"; | ||
TVM_DECLARE_FINAL_OBJECT_INFO(VarTableNode, Object); | ||
|
||
private: | ||
void RemoveVar(const ObjectRef& obj); | ||
|
||
struct VariableInfo { | ||
DocFactory doc_factory; | ||
Optional<String> name; | ||
}; | ||
std::unordered_map<ObjectRef, VariableInfo, ObjectPtrHash, ObjectPtrEqual> obj2info; | ||
std::unordered_set<String> defined_names; | ||
}; | ||
|
||
/*! | ||
* \brief Reference type of VarTableNode. | ||
*/ | ||
class VarTable : public ObjectRef { | ||
public: | ||
/*! | ||
* \brief Create an empty VarTable. | ||
*/ | ||
VarTable(); | ||
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(VarTable, ObjectRef, VarTableNode); | ||
}; | ||
|
||
} // namespace printer | ||
} // namespace script | ||
} // namespace tvm | ||
|
||
#endif // TVM_SCRIPT_PRINTER_VAR_TABLE_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""Functions to print doc into text format""" | ||
|
||
from typing import Callable, Optional | ||
|
||
from tvm._ffi import register_object | ||
from tvm.runtime import Object, ObjectPath | ||
|
||
from . import _ffi_api | ||
from .doc import ExprDoc, IdDoc | ||
from .frame import Frame | ||
|
||
|
||
@register_object("script.printer.VarTable") | ||
class VarTable(Object): | ||
""" | ||
Variable Table manages mapping from variable object to ExprDoc during | ||
the process of printing TVMScript. | ||
""" | ||
|
||
def __init__(self): | ||
""" | ||
Create an empty VarTable. | ||
""" | ||
self.__init_handle_by_constructor__(_ffi_api.VarTable) # type: ignore # pylint: disable=no-member | ||
|
||
def define(self, obj: Object, name_hint: str, object_path: ObjectPath, frame: Frame) -> IdDoc: | ||
""" | ||
Define a variable by name. | ||
Parameters | ||
---------- | ||
obj : Object | ||
The variable object. | ||
name_hint : str | ||
The hint for variable name. | ||
object_path : ObjectPath | ||
The object path to be associated with the returned ExprDoc. | ||
frame : Frame | ||
Then frame that this variable is defined in. | ||
Returns | ||
------- | ||
doc : IdDoc | ||
The doc for this variable. | ||
""" | ||
return _ffi_api.VarTableDefine(self, obj, name_hint, object_path, frame) # type: ignore # pylint: disable=no-member | ||
|
||
def define_by_doc(self, obj: Object, doc_factory: Callable[[], ExprDoc], frame: Frame) -> None: | ||
""" | ||
Define a variable by ExprDoc. | ||
Parameters | ||
---------- | ||
obj : Object | ||
The variable object. | ||
doc_factory : Callable[[], ExprDoc] | ||
The hint for variable name. | ||
frame : Frame | ||
Then frame that this variable is defined in. | ||
Returns | ||
------- | ||
None | ||
""" | ||
_ffi_api.VarTableDefineByDoc(self, obj, doc_factory, frame) # type: ignore # pylint: disable=no-member | ||
|
||
def get_var_doc(self, obj: Object, object_path: ObjectPath) -> Optional[ExprDoc]: | ||
""" | ||
Get the doc for a variable. | ||
Parameters | ||
---------- | ||
obj : Object | ||
The variable object. | ||
object_path : ObjectPath | ||
The object path to be associated with the returned ExprDoc. | ||
Returns | ||
------- | ||
doc : ExprDoc | ||
The doc for this variable. | ||
""" | ||
return _ffi_api.VarTableGetVarDoc(self, obj, object_path) # type: ignore # pylint: disable=no-member | ||
|
||
def is_var_defined(self, obj: Object) -> bool: | ||
""" | ||
Check whether a variable is defined. | ||
Parameters | ||
---------- | ||
obj : Object | ||
The variable object. | ||
Returns | ||
------- | ||
is_defined : bool | ||
Whether the variable is defined. | ||
""" | ||
return _ffi_api.VarTableIsVarDefined(self, obj) # type: ignore # pylint: disable=no-member | ||
|
||
def __contains__(self, obj: Object) -> bool: | ||
return self.is_var_defined(obj) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
#include <tvm/node/object_path.h> | ||
#include <tvm/runtime/container/optional.h> | ||
#include <tvm/runtime/logging.h> | ||
#include <tvm/runtime/registry.h> | ||
#include <tvm/script/printer/var_table.h> | ||
|
||
namespace tvm { | ||
namespace script { | ||
namespace printer { | ||
|
||
String GenerateUniqueName(const String& name_hint, std::unordered_set<String>* defined_names) { | ||
String name = name_hint; | ||
for (int i = 1; !defined_names->insert(name).second; ++i) { | ||
name = name_hint + "_" + std::to_string(i); | ||
} | ||
return name; | ||
} | ||
|
||
IdDoc VarTableNode::Define(const ObjectRef& obj, const String& name_hint, | ||
const ObjectPath& object_path, const Frame& frame) { | ||
String name = GenerateUniqueName(name_hint, &this->defined_names); | ||
DocFactory doc_factory = [name]() { return IdDoc(name); }; | ||
|
||
auto result = obj2info.insert({obj, VariableInfo{std::move(doc_factory), name}}); | ||
ICHECK(result.second) << "Duplicated object: " << obj; | ||
|
||
IdDoc def_doc(name); | ||
def_doc->source_paths.push_back(object_path); | ||
|
||
frame->AddExitCallback([this, obj]() { this->RemoveVar(obj); }); | ||
|
||
return def_doc; | ||
} | ||
|
||
void VarTableNode::DefineByDoc(const ObjectRef& obj, DocFactory doc_factory, const Frame& frame) { | ||
ICHECK(obj2info.find(obj) == obj2info.end()) << "Duplicated object: " << obj; | ||
|
||
ICHECK(!doc_factory()->IsInstance<IdDocNode>()) | ||
<< "VarTableNode::Define cannot be used for variable that's mapped to IdDoc."; | ||
|
||
obj2info.insert({obj, VariableInfo{std::move(doc_factory), NullOpt}}); | ||
|
||
frame->AddExitCallback([this, obj]() { this->RemoveVar(obj); }); | ||
} | ||
|
||
Optional<ExprDoc> VarTableNode::GetVarDoc(const ObjectRef& obj, | ||
const ObjectPath& object_path) const { | ||
auto it = obj2info.find(obj); | ||
if (it == obj2info.end()) { | ||
return NullOpt; | ||
} | ||
ExprDoc doc = it->second.doc_factory(); | ||
doc->source_paths.push_back(object_path); | ||
return doc; | ||
} | ||
|
||
bool VarTableNode::IsVarDefined(const ObjectRef& obj) const { return obj2info.count(obj); } | ||
|
||
void VarTableNode::RemoveVar(const ObjectRef& obj) { | ||
auto it = obj2info.find(obj); | ||
ICHECK(it != obj2info.end()) << "No such object: " << obj; | ||
|
||
if (it->second.name.defined()) { | ||
defined_names.erase(it->second.name.value()); | ||
} | ||
obj2info.erase(it); | ||
} | ||
|
||
VarTable::VarTable() { data_ = make_object<VarTableNode>(); } | ||
|
||
TVM_REGISTER_NODE_TYPE(VarTableNode); | ||
TVM_REGISTER_GLOBAL("script.printer.VarTable").set_body_typed([]() { return VarTable(); }); | ||
TVM_REGISTER_GLOBAL("script.printer.VarTableDefine") | ||
.set_body_method<VarTable, VarTableNode, IdDoc, const ObjectRef&, const String&, | ||
const ObjectPath&, const Frame&>(&VarTableNode::Define); | ||
TVM_REGISTER_GLOBAL("script.printer.VarTableDefineByDoc") | ||
.set_body_typed([](VarTable var_table, const ObjectRef& obj, runtime::PackedFunc factory, | ||
Frame frame) { | ||
var_table->DefineByDoc( | ||
obj, [f = std::move(factory)]() { return f(); }, frame); | ||
}); | ||
TVM_REGISTER_GLOBAL("script.printer.VarTableGetVarDoc") | ||
.set_body_method<VarTable>(&VarTableNode::GetVarDoc); | ||
TVM_REGISTER_GLOBAL("script.printer.VarTableIsVarDefined") | ||
.set_body_method<VarTable>(&VarTableNode::IsVarDefined); | ||
|
||
} // namespace printer | ||
} // namespace script | ||
} // namespace tvm |
Oops, something went wrong.