Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMScript] Optionally output the address as part of variable names #15579

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/tvm/node/script_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ class PrinterConfigNode : public Object {
int num_context_lines = -1;
/*! \brief Whether to output with syntax sugar, set false for complete printing. */
bool syntax_sugar = true;
/*! \brief Whether variable names should include the object's address */
bool show_object_address = false;
/* \brief Object path to be underlined */
Array<ObjectPath> path_to_underline = Array<ObjectPath>();
/*! \brief Object path to be annotated. */
Expand All @@ -91,6 +93,7 @@ class PrinterConfigNode : public Object {
v->Visit("print_line_numbers", &print_line_numbers);
v->Visit("num_context_lines", &num_context_lines);
v->Visit("syntax_sugar", &syntax_sugar);
v->Visit("show_object_address", &show_object_address);
v->Visit("path_to_underline", &path_to_underline);
v->Visit("path_to_annotate", &path_to_annotate);
v->Visit("obj_to_underline", &obj_to_underline);
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/runtime/script_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class PrinterConfig(Object):
print_line_numbers: bool
num_context_lines: int
syntax_sugar: bool
show_object_address: bool
path_to_underline: Optional[List[ObjectPath]]
path_to_annotate: Optional[Dict[ObjectPath, str]]
obj_to_underline: Optional[List[Object]]
Expand All @@ -60,6 +61,7 @@ def __init__(
print_line_numbers: bool = False,
num_context_lines: Optional[int] = None,
syntax_sugar: bool = True,
show_object_address: bool = True,
path_to_underline: Optional[List[ObjectPath]] = None,
path_to_annotate: Optional[Dict[ObjectPath, str]] = None,
obj_to_underline: Optional[List[Object]] = None,
Expand All @@ -79,6 +81,7 @@ def __init__(
"print_line_numbers": print_line_numbers,
"num_context_lines": num_context_lines,
"syntax_sugar": syntax_sugar,
"show_object_address": show_object_address,
"path_to_underline": path_to_underline,
"path_to_annotate": path_to_annotate,
"obj_to_underline": obj_to_underline,
Expand Down Expand Up @@ -119,6 +122,7 @@ def script(
print_line_numbers: bool = False,
num_context_lines: int = -1,
syntax_sugar: bool = True,
show_object_address: bool = False,
path_to_underline: Optional[List[ObjectPath]] = None,
path_to_annotate: Optional[Dict[ObjectPath, str]] = None,
obj_to_underline: Optional[List[Object]] = None,
Expand Down Expand Up @@ -153,6 +157,8 @@ def script(
The number of lines of context to print before and after the line to underline.
syntax_sugar: bool = True
Whether to output with syntax sugar, set false for complete printing.
show_object_address: bool = False
Whether to include the object's adddress as part of the TVMScript name
path_to_underline : Optional[List[ObjectPath]] = None
Object path to be underlined
path_to_annotate : Optional[Dict[ObjectPath, str]] = None
Expand Down Expand Up @@ -182,6 +188,7 @@ def script(
print_line_numbers=print_line_numbers,
num_context_lines=num_context_lines,
syntax_sugar=syntax_sugar,
show_object_address=show_object_address,
path_to_underline=path_to_underline,
path_to_annotate=path_to_annotate,
obj_to_underline=obj_to_underline,
Expand All @@ -206,6 +213,7 @@ def show(
print_line_numbers: bool = False,
num_context_lines: int = -1,
syntax_sugar: bool = True,
show_object_address: bool = True,
path_to_underline: Optional[List[ObjectPath]] = None,
path_to_annotate: Optional[Dict[ObjectPath, str]] = None,
obj_to_underline: Optional[List[Object]] = None,
Expand Down Expand Up @@ -245,6 +253,8 @@ def show(
The number of lines of context to print before and after the line to underline.
syntax_sugar: bool = True
Whether to output with syntax sugar, set false for complete printing.
show_object_address: bool = False
Whether to include the object's adddress as part of the TVMScript name
path_to_underline : Optional[List[ObjectPath]] = None
Object path to be underlined
path_to_annotate : Optional[Dict[ObjectPath, str]] = None
Expand Down Expand Up @@ -272,6 +282,7 @@ def show(
print_line_numbers=print_line_numbers,
num_context_lines=num_context_lines,
syntax_sugar=syntax_sugar,
show_object_address=show_object_address,
path_to_underline=path_to_underline,
path_to_annotate=path_to_annotate,
obj_to_underline=obj_to_underline,
Expand Down
4 changes: 4 additions & 0 deletions src/node/script_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ PrinterConfig::PrinterConfig(Map<String, ObjectRef> config_dict) {
if (auto v = config_dict.Get("syntax_sugar")) {
n->syntax_sugar = Downcast<IntImm>(v)->value;
}
if (auto v = config_dict.Get("show_object_address")) {
n->show_object_address = Downcast<IntImm>(v)->value;
}

this->data_ = std::move(n);
}

Expand Down
10 changes: 9 additions & 1 deletion src/script/printer/ir_docsifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#include <tvm/runtime/registry.h>
#include <tvm/script/printer/ir_docsifier.h>

#include <sstream>

#include "./utils.h"

namespace tvm {
Expand All @@ -29,7 +31,13 @@ namespace printer {

IdDoc IRDocsifierNode::Define(const ObjectRef& obj, const Frame& frame, const String& name_hint) {
ICHECK(obj2info.find(obj) == obj2info.end()) << "Duplicated object: " << obj;
String name = GenerateUniqueName(name_hint, this->defined_names);
String name = name_hint;
if (cfg->show_object_address) {
std::stringstream stream;
stream << name << "_" << obj.get();
name = stream.str();
}
name = GenerateUniqueName(name, this->defined_names);
this->defined_names.insert(name);
DocCreator doc_factory = [name]() { return IdDoc(name); };
obj2info.insert({obj, VariableInfo{std::move(doc_factory), name}});
Expand Down
41 changes: 41 additions & 0 deletions tests/python/unittest/test_tvmscript_printer_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-docstring

import re

import tvm.testing
from tvm import ir, tir
from tvm.ir import Range
Expand Down Expand Up @@ -798,5 +801,43 @@ def main():
_assert_print(root_block_explicitly, expected_output)


def test_variable_with_cpp_address():
"""The show_object_address option displays the C++ addressess

Because the C++ address may vary with each execution, the output
produced with this option cannot be compared to a fixed string.
Instead, this test uses the normal script output to generate a
regular expression against with the test output must match. The
regular expression validates that all names have been appended
with "_0x" followed by a hexadecimal number, and that the address
is the same for each variable.
"""
from tvm.script import tir as T

# The test function has all named objects suffixed with "_name",
# to avoid spurious replacement when generating the expected
# regex.
@T.prim_func
def func(a_name: T.handle):
N_name = T.int64()
A_name = T.match_buffer(a_name, N_name, "float32")
for i_name in range(N_name):
A_name[i_name] = A_name[i_name] + 1.0

without_address = func.script(show_object_address=False)
script = func.script(show_object_address=True)

expected_regex = re.escape(without_address)
for name in ["a_name", "A_name", "N_name", "i_name"]:
# Replace all occurrences with a backref to an earlier match
expected_regex = expected_regex.replace(name, rf"(?P={name})")
# Then replace the first such backref with a capturing group.
expected_regex = expected_regex.replace(
rf"(?P={name})", rf"(?P<{name}>{name}_0x[A-Fa-f0-9]+)", 1
)

assert re.match(expected_regex, script)


if __name__ == "__main__":
tvm.testing.main()