Skip to content

Commit

Permalink
[TVMScript] Printer entry point (#12462)
Browse files Browse the repository at this point in the history
This PR:

- Adds an entry point for the TVMScript Unified Printer
- Adds a helper object class `RootNodeContainer` to provide an injection point for the actual printer implementation to add specialized logic on the root node to print.

Tracking issue: #11912
  • Loading branch information
yelite authored Aug 21, 2022
1 parent 92355f2 commit cc769fd
Show file tree
Hide file tree
Showing 11 changed files with 355 additions and 3 deletions.
56 changes: 56 additions & 0 deletions include/tvm/script/printer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_SCRIPT_PRINTER_H_
#define TVM_SCRIPT_PRINTER_H_

#include <tvm/node/node.h>
#include <tvm/node/object_path.h>

namespace tvm {
namespace script {
namespace printer {

/*!
* \brief Print IR graph as TVMScript code
*
* \param root_node The root node to print.
* \param ir_name The dispatch token of the target IR, e.g., "tir", "relax".
* \param ir_prefix The symbol name for TVMScript IR namespaces. For example, {"tir": "T"}.
* \param indent_spaces Number of spaces used for indentation
* \param print_line_numbers Whether to print line numbers
* \param num_context_lines Number of context lines to print around the underlined text
* \param path_to_underline Object path to be underlined
*
* \return the TVMScript code as string.
*/
String Script( //
const ObjectRef& root_node, //
String ir_name, //
Map<String, String> ir_prefix, //
int indent_spaces = 4, //
bool print_line_numbers = false, //
int num_context_lines = -1, //
Optional<ObjectPath> path_to_underline = NullOpt //
);

} // namespace printer
} // namespace script
} // namespace tvm

#endif // TVM_SCRIPT_PRINTER_H_
6 changes: 6 additions & 0 deletions include/tvm/script/printer/doc.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,12 @@ class ExprDoc : public Doc {
ExprDoc() = default;

public:
/*!
* \brief Create a doc representing index access on the current ExprDoc
* \param indices The indices to access.
*/
ExprDoc operator[](Array<Doc> indices) const;

TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ExprDoc, Doc, ExprDocNode);
};

Expand Down
41 changes: 41 additions & 0 deletions include/tvm/script/printer/ir_docsifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,47 @@ class IRDocsifier : public ObjectRef {
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRDocsifier, ObjectRef, IRDocsifierNode);
};

/*!
* \brief A wrapper object to provide injection point for printer of each IR.
*
* For any IR node to be transformed by IRDocsifier, it will be wrapped by RootNodeContainer
* and be dispatched to the corresponding function first. This provides an injection point for
* each IR's printer implemention to add specialized logic, for example, pushing a special
* Frame to the IRDocsifier before doing any IR->Doc transformation.
*
* \code
* TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
* .set_dispatch("relax", [](TracedObject<RootNodeContainer> obj, IRDocsifier p) {
* const ObjectRef& root_node = obj.Get()->root_node;
* // For example, relax printer can create a Frame specialized to Relax here
* RelaxGeneralFrame frame;
* auto ctx = p->WithFrame(frame);
* // More specialized logic for your IR.
* return p->AsDoc<Doc>(MakeTraced(root_node));
* });
* \endcode
*/
class RootNodeContainerNode : public Object {
public:
/*! \brief The root node to print. */
ObjectRef root_node;

void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("root_node", &root_node); }

static constexpr const char* _type_key = "script.printer.RootNodeContainer";
TVM_DECLARE_FINAL_OBJECT_INFO(RootNodeContainerNode, Object);
};

class RootNodeContainer : public ObjectRef {
public:
/*!
* \brief Constructor of RootNodeContainer.
* \param root_node The root node to print.
* */
explicit RootNodeContainer(ObjectRef root_node);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RootNodeContainer, ObjectRef, RootNodeContainerNode);
};

} // namespace printer
} // namespace script
} // namespace tvm
Expand Down
1 change: 1 addition & 0 deletions python/tvm/script/printer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@
"""

from . import _ffi_api
from .entry import script
71 changes: 71 additions & 0 deletions python/tvm/script/printer/entry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This file contains the entry point of TVMScript Unified Printer.
"""

from typing import Dict, Optional

from tvm.runtime import Object, ObjectPath

from . import _ffi_api


def script( # pylint: disable=too-many-arguments
root_node: Object,
ir_name: str,
ir_prefix: Dict[str, str],
indent_spaces: int = 4,
print_line_numbers: bool = False,
num_context_lines: int = -1,
path_to_underline: Optional[ObjectPath] = None,
) -> str:
"""
Print IR graph as TVMScript code
Parameters
----------
root_node : Object
The root node to print.
ir_name : str
The dispatch token of the target IR, e.g., "tir", "relax".
ir_prefix : Dict[str, str]
The symbol name for TVMScript IR namespaces. For example,
{"tir": "T"}.
indent_spaces : int
The number of indent spaces to use in the output
print_line_numbers: bool
Whether to print line numbers
num_context_lines : Optional[int]
Number of context lines to print around the underlined text
path_to_underline : Optional[ObjectPath]
Object path to be underlined
Returns
-------
script : str
The TVMScript code of the root_node
"""
return _ffi_api.Script( # type: ignore # pylint: disable=no-member
root_node,
ir_name,
ir_prefix,
indent_spaces,
print_line_numbers,
num_context_lines,
path_to_underline,
)
51 changes: 49 additions & 2 deletions python/tvm/script/printer/ir_docsifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,21 @@ def _ensure_cleanup_function_registered():
_CLEANUP_REGISTERED = True


@register_object("script.printer.RootNodeContainer")
class RootNodeContainer(Object):
"""
A wrapper object to provide injection point for printer of each IR.
This class shouldn't be used directly. `IRDocsifier.set_root_dispatch`
should be used instead.
"""

root_node: Object

def __init__(self, root_node: Object):
self.__init_handle_by_constructor__(_ffi_api.RootNodeContainer, root_node) # type: ignore # pylint: disable=no-member


@register_object("script.printer.IRDocsifier")
class IRDocsifier(Object):
"""
Expand Down Expand Up @@ -91,7 +106,7 @@ def __init__(self, ir_prefix: Dict[str, str]):
def set_dispatch(
cls,
node_type: Type[_TObject],
dispatch_function: Callable[[_TObject, "IRDocsifier"], Doc],
dispatch_function: Callable[[_TObject, ObjectPath, "IRDocsifier"], Doc],
dispatch_token: str = "",
) -> None:
"""
Expand All @@ -101,7 +116,7 @@ def set_dispatch(
----------
node_type : Type[_TObject]
The type of object to dispatch on.
dispatch_function : Callable[[_TObject, "IRDocsifier"], Doc]
dispatch_function : Callable[[_TObject, ObjectPath, "IRDocsifier"], Doc]
The dispatch function. It's called to transform IR node object to Doc.
dispatch_token : str
Function will only be called when this dispatch_token is the same as the one
Expand All @@ -119,6 +134,38 @@ def set_dispatch(
)
_REGISTERED_TYPES.add((dispatch_token, type_index))

@classmethod
def set_root_dispatch(
cls, dispatch_token: str, root_dispatch_function: Callable[[Object, "IRDocsifier"], Doc]
) -> None:
"""
Set the root dispatch function for an IR.
The root dispatch function will be called with the root node of an IR graph
that's being transformed to Doc. This provides an injection point for
each IR's printer implemention to add specialized logic, for example,
pushing a special Frame to the IRDocsifier before doing actual IR->Doc
transformation.
The simplest root dispatch function is
```
def f(obj, ir_docsifier)
return ir_docsifier.as_doc(obj, ObjectPath.root())
```
Parameters
----------
root_dispatch_function : Callable[[_TObject, "IRDocsifier"], Doc]
The root dispatch function. It's called with the root node to be printed.
dispatch_token : str
The dispatch token of the IR that root_dispatch_funnction applies to.
"""

def dispatch_function(obj: RootNodeContainer, _, ir_docsifier):
return root_dispatch_function(obj.root_node, ir_docsifier)

cls.set_dispatch(RootNodeContainer, dispatch_function, dispatch_token)

def as_doc(self, obj: Object, object_path: ObjectPath) -> Doc:
"""
Transform the input object into Doc.
Expand Down
54 changes: 54 additions & 0 deletions src/script/printer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <tvm/runtime/registry.h>
#include <tvm/script/printer.h>
#include <tvm/script/printer/doc.h>
#include <tvm/script/printer/doc_printer.h>
#include <tvm/script/printer/frame.h>
#include <tvm/script/printer/ir_docsifier.h>

namespace tvm {
namespace script {
namespace printer {

String Script( //
const ObjectRef& root_node, //
String ir_name, //
Map<String, String> ir_prefix, //
int indent_spaces, //
bool print_line_numbers, //
int num_context_lines, //
Optional<ObjectPath> path_to_underline //
) {
IRDocsifier ir_docsifier(ir_prefix);

auto dispatch_ctx = ir_docsifier->WithDispatchToken(ir_name);

Doc doc = ir_docsifier->AsDoc<Doc>(MakeTraced(RootNodeContainer(root_node)));

return DocToPythonScript(doc, indent_spaces, print_line_numbers, num_context_lines,
path_to_underline);
}

TVM_REGISTER_GLOBAL("script.printer.Script").set_body_typed(&Script);

} // namespace printer
} // namespace script
} // namespace tvm
2 changes: 2 additions & 0 deletions src/script/printer/doc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ ExprDoc ExprDocNode::Call(Array<ExprDoc, void> args, Array<String, void> kwargs_
return CallDoc(GetRef<ExprDoc>(this), args, kwargs_keys, kwargs_values);
}

ExprDoc ExprDoc::operator[](Array<Doc> indices) const { return (*get())[indices]; }

StmtBlockDoc::StmtBlockDoc(Array<StmtDoc> stmts) {
ObjectPtr<StmtBlockDocNode> n = make_object<StmtBlockDocNode>();
n->stmts = stmts;
Expand Down
32 changes: 32 additions & 0 deletions src/script/printer/ir_docsifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
* under the License.
*/
#include <tvm/runtime/container/base.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/registry.h>
#include <tvm/script/printer/ir_docsifier.h>
#include <tvm/script/printer/traced_object.h>
Expand All @@ -42,6 +43,31 @@ IRDocsifier::FType& IRDocsifier::vtable() {
return inst;
}

RootNodeContainer::RootNodeContainer(ObjectRef root_node) {
auto n = make_object<RootNodeContainerNode>();
n->root_node = std::move(root_node);
data_ = std::move(n);
}

// Add a default dispatch for the RootNodeContainer to throw error.
// To add implementation for a new IR, RootNodeContainer needs to be
// registered under the dispatch token of that IR, like:
// \code
// TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
// .set_dispatch("relax", [](TracedObject<RootNodeContainer> obj, IRDocsifier p) {
// const ObjectRef& root_node = obj.Get()->root_node;
// \\ More specialized logic for your IR.
// return p->AsDoc<Doc>(MakeTraced(root_node));
// });
// \endcode
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch([](TracedObject<RootNodeContainer> obj, IRDocsifier p) -> Doc {
String top_dispatch_token = p->dispatch_tokens.back();
ICHECK_NE(top_dispatch_token, "");
ICHECK(false) << "Printing IR " << top_dispatch_token << " is not implemented.";
throw;
});

TVM_REGISTER_NODE_TYPE(IRDocsifierNode);
TVM_REGISTER_GLOBAL("script.printer.IRDocsifier").set_body_typed([](Map<String, String> ir_prefix) {
return IRDocsifier(ir_prefix);
Expand Down Expand Up @@ -71,6 +97,12 @@ TVM_REGISTER_GLOBAL("script.printer.IRDocsifierRemoveDispatch")
.set_body_typed([](String token, uint64_t type_index) {
IRDocsifier::vtable().remove_dispatch(token, type_index);
});

TVM_REGISTER_NODE_TYPE(RootNodeContainerNode);
TVM_REGISTER_GLOBAL("script.printer.RootNodeContainer").set_body_typed([](ObjectRef root_node) {
return RootNodeContainer(root_node);
});

} // namespace printer
} // namespace script
} // namespace tvm
Loading

0 comments on commit cc769fd

Please sign in to comment.