Skip to content

Commit

Permalink
[TVMScript] Printer VarTable (#12336)
Browse files Browse the repository at this point in the history
This PR:

- Adds VarTable for the new TVMScript Printer

Compared to the prototype version, this:

- Removes unnecessary public methods.
  - GetObjectName
  - GetUniqueName
- Add Frame parameter for `Define` methods. VarTable will add callback to Frame to remove variable when Frame exits.
- Changes DocFactory from `ExprDoc(ObjectPath)` to `ExprDoc()` to simplify var definition.

Tracking issue: #11912
  • Loading branch information
yelite authored Aug 13, 2022
1 parent 036aa72 commit d33a332
Show file tree
Hide file tree
Showing 5 changed files with 617 additions and 0 deletions.
144 changes: 144 additions & 0 deletions include/tvm/script/printer/var_table.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_SCRIPT_PRINTER_VAR_TABLE_H_
#define TVM_SCRIPT_PRINTER_VAR_TABLE_H_

#include <tvm/node/node.h>
#include <tvm/node/object_path.h>
#include <tvm/script/printer/doc.h>
#include <tvm/script/printer/frame.h>
#include <tvm/script/printer/traced_object.h>

#include <unordered_map>
#include <unordered_set>

namespace tvm {
namespace script {
namespace printer {

/*!
* \brief Variable Table manages mapping from variable object to ExprDoc during
* the process of printing TVMScript.
*
* The value type of this map is ExprDoc rather than IdDoc or String. It's
* because variables can be implicitly defined. For example in TIR buffer (tir::Buffer),
* `buf->data` is a variable, while its representation in TVMScript should be an
* expression `x.data`, where `x` is the variable for the buffer itself.
*/
class VarTableNode : public Object {
public:
void VisitAttrs(AttrVisitor*) {}

/*!
* \brief Define variable by name.
* \param obj The variable object.
* \param name_hint The hint for variable name.
* \param object_path The object_path for the returned ExprDoc.
* \param frame The frame that this variable is defined in.
*
* \return The id doc for this variable.
*
* This function will rename the variable to avoid name conflict with other variables
* in the table.
*/
IdDoc Define(const ObjectRef& obj, const String& name_hint, const ObjectPath& object_path,
const Frame& frame);

/*!
* \brief Define variable by name.
* \param obj The variable object.
* \param name_hint The hint for variable name.
* \param frame The frame that this variable is defined in.
*
* \return The id doc for this variable.
*
* This is a shortcut version of `Define` which accepts a traced string.
*/
IdDoc Define(const ObjectRef& obj, const TracedObject<String>& name_hint, const Frame& frame) {
return Define(obj, name_hint.Get(), name_hint.GetPath(), frame);
}

using DocFactory = std::function<ExprDoc()>;

/*!
* \brief Define variable by doc factory.
* \param obj The variable object.
* \param doc_factory The function to return an ExprDoc object for this variable.
* \param frame The frame that this variable is defined in.
*
* This function is a special form of `Define`. Variable is mapped to ExprDoc rather
* than IdDoc. It's useful when a variable is implicitly defined without a name, like
* the buf->data in TIR, which should be mapped to `AttrDoc(IdDoc("<buffer_name>"), "data")`.
*
* This function takes a DocFactory instead of Doc. It's because GetVarDoc needs to
* return a new Doc object every time it's called, as the returned doc will have
* different `soruce_path`. Currently there isn't a good way to deep copy a TVMObject
* so VarTable needs to call a factory function to get a freshly-constructed Doc object
* every time GetVarDoc is called.
*/
void DefineByDoc(const ObjectRef& obj, DocFactory doc_factory, const Frame& frame);

/*!
* \brief Get the doc for variable.
* \param obj The variable object.
* \param object_path The object path for the variable.
*
* \return The doc for variable, if it exists in the table. Otherwise it returns NullOpt.
*/
Optional<ExprDoc> GetVarDoc(const ObjectRef& obj, const ObjectPath& object_path) const;

/*!
* \brief Check if a variable exists in the table.
* \param obj The variable object.
*
* \return a boolean for whether variable exists.
*/
bool IsVarDefined(const ObjectRef& obj) const;

static constexpr const char* _type_key = "script.printer.VarTable";
TVM_DECLARE_FINAL_OBJECT_INFO(VarTableNode, Object);

private:
void RemoveVar(const ObjectRef& obj);

struct VariableInfo {
DocFactory doc_factory;
Optional<String> name;
};
std::unordered_map<ObjectRef, VariableInfo, ObjectPtrHash, ObjectPtrEqual> obj2info;
std::unordered_set<String> defined_names;
};

/*!
* \brief Reference type of VarTableNode.
*/
class VarTable : public ObjectRef {
public:
/*!
* \brief Create an empty VarTable.
*/
VarTable();
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(VarTable, ObjectRef, VarTableNode);
};

} // namespace printer
} // namespace script
} // namespace tvm

#endif // TVM_SCRIPT_PRINTER_VAR_TABLE_H_
118 changes: 118 additions & 0 deletions python/tvm/script/printer/var_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Functions to print doc into text format"""

from typing import Callable, Optional

from tvm._ffi import register_object
from tvm.runtime import Object, ObjectPath

from . import _ffi_api
from .doc import ExprDoc, IdDoc
from .frame import Frame


@register_object("script.printer.VarTable")
class VarTable(Object):
"""
Variable Table manages mapping from variable object to ExprDoc during
the process of printing TVMScript.
"""

def __init__(self):
"""
Create an empty VarTable.
"""
self.__init_handle_by_constructor__(_ffi_api.VarTable) # type: ignore # pylint: disable=no-member

def define(self, obj: Object, name_hint: str, object_path: ObjectPath, frame: Frame) -> IdDoc:
"""
Define a variable by name.
Parameters
----------
obj : Object
The variable object.
name_hint : str
The hint for variable name.
object_path : ObjectPath
The object path to be associated with the returned ExprDoc.
frame : Frame
Then frame that this variable is defined in.
Returns
-------
doc : IdDoc
The doc for this variable.
"""
return _ffi_api.VarTableDefine(self, obj, name_hint, object_path, frame) # type: ignore # pylint: disable=no-member

def define_by_doc(self, obj: Object, doc_factory: Callable[[], ExprDoc], frame: Frame) -> None:
"""
Define a variable by ExprDoc.
Parameters
----------
obj : Object
The variable object.
doc_factory : Callable[[], ExprDoc]
The hint for variable name.
frame : Frame
Then frame that this variable is defined in.
Returns
-------
None
"""
_ffi_api.VarTableDefineByDoc(self, obj, doc_factory, frame) # type: ignore # pylint: disable=no-member

def get_var_doc(self, obj: Object, object_path: ObjectPath) -> Optional[ExprDoc]:
"""
Get the doc for a variable.
Parameters
----------
obj : Object
The variable object.
object_path : ObjectPath
The object path to be associated with the returned ExprDoc.
Returns
-------
doc : ExprDoc
The doc for this variable.
"""
return _ffi_api.VarTableGetVarDoc(self, obj, object_path) # type: ignore # pylint: disable=no-member

def is_var_defined(self, obj: Object) -> bool:
"""
Check whether a variable is defined.
Parameters
----------
obj : Object
The variable object.
Returns
-------
is_defined : bool
Whether the variable is defined.
"""
return _ffi_api.VarTableIsVarDefined(self, obj) # type: ignore # pylint: disable=no-member

def __contains__(self, obj: Object) -> bool:
return self.is_var_defined(obj)
108 changes: 108 additions & 0 deletions src/script/printer/var_table.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <tvm/node/object_path.h>
#include <tvm/runtime/container/optional.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/registry.h>
#include <tvm/script/printer/var_table.h>

namespace tvm {
namespace script {
namespace printer {

String GenerateUniqueName(const String& name_hint, std::unordered_set<String>* defined_names) {
String name = name_hint;
for (int i = 1; !defined_names->insert(name).second; ++i) {
name = name_hint + "_" + std::to_string(i);
}
return name;
}

IdDoc VarTableNode::Define(const ObjectRef& obj, const String& name_hint,
const ObjectPath& object_path, const Frame& frame) {
String name = GenerateUniqueName(name_hint, &this->defined_names);
DocFactory doc_factory = [name]() { return IdDoc(name); };

auto result = obj2info.insert({obj, VariableInfo{std::move(doc_factory), name}});
ICHECK(result.second) << "Duplicated object: " << obj;

IdDoc def_doc(name);
def_doc->source_paths.push_back(object_path);

frame->AddExitCallback([this, obj]() { this->RemoveVar(obj); });

return def_doc;
}

void VarTableNode::DefineByDoc(const ObjectRef& obj, DocFactory doc_factory, const Frame& frame) {
ICHECK(obj2info.find(obj) == obj2info.end()) << "Duplicated object: " << obj;

ICHECK(!doc_factory()->IsInstance<IdDocNode>())
<< "VarTableNode::Define cannot be used for variable that's mapped to IdDoc.";

obj2info.insert({obj, VariableInfo{std::move(doc_factory), NullOpt}});

frame->AddExitCallback([this, obj]() { this->RemoveVar(obj); });
}

Optional<ExprDoc> VarTableNode::GetVarDoc(const ObjectRef& obj,
const ObjectPath& object_path) const {
auto it = obj2info.find(obj);
if (it == obj2info.end()) {
return NullOpt;
}
ExprDoc doc = it->second.doc_factory();
doc->source_paths.push_back(object_path);
return doc;
}

bool VarTableNode::IsVarDefined(const ObjectRef& obj) const { return obj2info.count(obj); }

void VarTableNode::RemoveVar(const ObjectRef& obj) {
auto it = obj2info.find(obj);
ICHECK(it != obj2info.end()) << "No such object: " << obj;

if (it->second.name.defined()) {
defined_names.erase(it->second.name.value());
}
obj2info.erase(it);
}

VarTable::VarTable() { data_ = make_object<VarTableNode>(); }

TVM_REGISTER_NODE_TYPE(VarTableNode);
TVM_REGISTER_GLOBAL("script.printer.VarTable").set_body_typed([]() { return VarTable(); });
TVM_REGISTER_GLOBAL("script.printer.VarTableDefine")
.set_body_method<VarTable, VarTableNode, IdDoc, const ObjectRef&, const String&,
const ObjectPath&, const Frame&>(&VarTableNode::Define);
TVM_REGISTER_GLOBAL("script.printer.VarTableDefineByDoc")
.set_body_typed([](VarTable var_table, const ObjectRef& obj, runtime::PackedFunc factory,
Frame frame) {
var_table->DefineByDoc(
obj, [f = std::move(factory)]() { return f(); }, frame);
});
TVM_REGISTER_GLOBAL("script.printer.VarTableGetVarDoc")
.set_body_method<VarTable>(&VarTableNode::GetVarDoc);
TVM_REGISTER_GLOBAL("script.printer.VarTableIsVarDefined")
.set_body_method<VarTable>(&VarTableNode::IsVarDefined);

} // namespace printer
} // namespace script
} // namespace tvm
Loading

0 comments on commit d33a332

Please sign in to comment.