diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 7313b4f78349..7ae5adfe4055 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -95,6 +95,12 @@ class IRModuleNode : public Object { return GetAttr(attr_key, Optional(default_value)); } + /*! + * \brief Get the metadata attributes. + * \returns The additional meta-data attributes + */ + DictAttrs GetAttrs() const { return attrs; } + /*! * \brief Check whether the module has an non-zero integer attr. * @@ -357,7 +363,7 @@ class IRModule : public ObjectRef { * \param type_definitions Type definitions in the module. * \param import_set Set of imported files in the module. * \param map The module source map. - * \param attrs The module attributes. + * \param attrs The module meta-data attributes. */ TVM_DLL explicit IRModule(Map functions, Map type_definitions = {}, diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 1017d3ebc368..cd7f4c2cbede 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -107,7 +107,6 @@ class VarNode : public ExprNode { bool SEqualReduce(const VarNode* other, SEqualReducer equal) const { equal->MarkGraphNode(); return equal(vid, other->vid) && equal(type_annotation, other->type_annotation) && - // Do we use the analysis information in equality? equal(checked_type_, other->checked_type_) && equal(shape_, other->shape_); } diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 00f45f2591d4..b1cd542121f6 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -85,6 +85,7 @@ class ConstantNode : public ExprNode { v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); + v->Visit("shape_", &shape_); } bool SEqualReduce(const ConstantNode* other, SEqualReducer equal) const { diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index 5b26d5e4fb79..9ce08f274cce 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -248,7 +248,7 @@ def assert_structural_equal(lhs, rhs, map_free_vars=False): The left operand. rhs : Object - The left operand. + The right operand. map_free_vars : bool Whether or not shall we map free vars that does diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index 9512b631cd4b..df3f94f1fb1a 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -16,7 +16,7 @@ # under the License. """IRModule that holds the functions and type definitions.""" from typing import Optional - +import ast from tvm._ffi.base import string_types import tvm._ffi @@ -39,7 +39,7 @@ class IRModule(Node): Map of global var to BaseFunc """ - def __init__(self, functions=None, type_definitions=None): + def __init__(self, functions=None, type_definitions=None, attrs=None): if functions is None: functions = {} elif isinstance(functions, dict): @@ -62,7 +62,17 @@ def __init__(self, functions=None, type_definitions=None): raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]") mapped_type_defs[k] = v type_definitions = mapped_type_defs - self.__init_handle_by_constructor__(_ffi_api.IRModule, functions, type_definitions) + + attrs = None if not attrs else attrs + if attrs is not None: + attrs = ast.literal_eval(str(attrs)) + attrs = tvm.ir.make_node("DictAttrs", **attrs) + self.__init_handle_by_constructor__( + _ffi_api.IRModule, + functions, + type_definitions, + attrs, + ) def __setitem__(self, var, val): """Add a mapping to the module. @@ -308,6 +318,17 @@ def get_attr(self, attr_key): return _ffi_api.Module_GetAttr(self, attr_key) + def get_attrs(self): + """Get the meta_data attributes. + + Returns + ------- + meta_data : DictAttrs + meta_data attributes + """ + + return _ffi_api.Module_GetAttrs(self) + def with_attr(self, attr_key, attr_value): """Copy the IRModule and add an attribute to it. diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index e6c102545aa6..90ab7bfb3649 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -31,6 +31,7 @@ Call = relay.Call If = relay.If const = relay.const +Constant = relay.Constant @tvm._ffi.register_object("relax.expr.ShapeExpr") diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py index 137f70a06ca5..8d0f61b39e43 100644 --- a/python/tvm/relax/utils.py +++ b/python/tvm/relax/utils.py @@ -19,8 +19,20 @@ from tvm.tir import PrimFunc from tvm import IRModule -# Simply extracts tir PrimFuncs from the input IRModule + def tir_partitioner(mod: IRModule) -> List[IRModule]: + """Extracts tir PrimFuncs from the input IRModule. + + Parameters + ---------- + mod : IRModule + The input IRModule. + + Returns + ------- + output : List[IRModule] + The result tir PrimFuncs. + """ partitions = [] for gvar in mod.get_global_vars(): if isinstance(mod[gvar], PrimFunc): @@ -28,3 +40,45 @@ def tir_partitioner(mod: IRModule) -> List[IRModule]: tir_mod[gvar] = mod[gvar] partitions.append(tir_mod) return partitions + + +def metadata_partitioner(rx_txt: str) -> List[str]: + """Extract Relax program and metadata section. + + Parameters + ---------- + rx_txt : str + The input relax text. + + Returns + ------- + output : List[str] + The result list of partitioned text, the first element + is the relax program, and the second is metadata section. + """ + partitions = [] + left_curly = 0 + meta_start = 0 + meta_end = 0 + for i, char in enumerate(rx_txt): + if i < 0: + raise ValueError("The program is invalid.") + if char == "{": + if meta_start == 0: + meta_start = i + left_curly += 1 + elif char == "}": + left_curly -= 1 + if left_curly == 0: + meta_end = i + 1 + break + + if meta_end == 0: + raise ValueError("The metadata section was not found.") + metadata = rx_txt[meta_start:meta_end] + rx_program = rx_txt[meta_end:-1] + + partitions.append(rx_program) + partitions.append(metadata) + + return partitions diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index c734efca3007..36db57118f69 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -20,11 +20,12 @@ different python versions. Synr also provides an error handling context that we use for error reporting. """ -# pylint: disable=invalid-name, inconsistent-return-statements, no-else-return, broad-except +# pylint: disable=invalid-name, inconsistent-return-statements, no-else-return, broad-except, import-outside-toplevel import types import json import operator import inspect +import functools from typing import Any, Callable, Dict, List, Optional, Union from synr import ast, Transformer, to_ast @@ -1364,7 +1365,7 @@ def from_source( raise TypeError("Only function definitions are supported.") -def ir_module(input_module: type) -> IRModule: +def ir_module(input_module=None, metadata=None) -> IRModule: """Decorate a python class as tvm IRModule. Parameters @@ -1372,18 +1373,32 @@ def ir_module(input_module: type) -> IRModule: input_module : type The python class to be parsed. + metadata : Optional[Union[str, DictAttrs]] + The metadata attributes to be parsed. + Returns ------- - output : IRModule + mod : IRModule The result IRModule. """ - if inspect.isclass(input_module): - func_dict = { - name: f for name, f in input_module.__dict__.items() if isinstance(f, BaseFunc) - } - mod = IRModule(func_dict) - mod = relax.transform.ResolveGlobals()(mod) - # FIXME(@altanh): where is the source map? - return mod - - raise TypeError("Only class definitions are supported.") + if metadata is not None: + from .relax.parser import RelaxTransformer as _RelaxTransformer + + _RelaxTransformer.update_meta(metadata) + + if input_module is None: + return functools.partial(ir_module, metadata=metadata) + + def _ir_module(input_module: type) -> IRModule: + if inspect.isclass(input_module): + func_dict = { + name: f for name, f in input_module.__dict__.items() if isinstance(f, BaseFunc) + } + mod = IRModule(func_dict, attrs=metadata) + mod = relax.transform.ResolveGlobals()(mod) + # FIXME(@altanh): where is the source map? + return mod + + raise TypeError("Only class definitions are supported.") + + return _ir_module(input_module) diff --git a/python/tvm/script/relax/function.py b/python/tvm/script/relax/function.py index 2ffffc1ae3e1..da861940b130 100644 --- a/python/tvm/script/relax/function.py +++ b/python/tvm/script/relax/function.py @@ -15,16 +15,18 @@ # specific language governing permissions and limitations # under the License. """TVM Script Interface for Relax Functions""" +# pylint: disable=import-outside-toplevel import inspect from typing import Callable +import functools from tvm.relax import Function from .parser import from_source -def function(input_func: Callable) -> Function: +def function(input_func=None, metadata=None) -> Function: """Decorate a Python function as a Relax function in TVM script. Parameters @@ -32,15 +34,28 @@ def function(input_func: Callable) -> Function: input_func : Callable The function to be parsed. + metadata : Optional[Union[str, DictAttrs]] + The meta_data attributes to be parsed. + Returns ------- output : Function The parsed Relax Function. """ - if inspect.isfunction(input_func): - result = from_source(input_func) - result.__name__ = input_func.__name__ - result.__qualname__ = input_func.__qualname__ - return result + if metadata is not None: + from .parser import RelaxTransformer as _RelaxTransformer + + _RelaxTransformer.update_meta(metadata) + + if input_func is None: + return functools.partial(function, metadata=metadata) + + def _function(input_func: Callable) -> Function: + if inspect.isfunction(input_func): + result = from_source(input_func) + result.__name__ = input_func.__name__ + result.__qualname__ = input_func.__qualname__ + return result + raise TypeError("Only function definitions are supported.") - raise TypeError("Only function definitions are supported.") + return _function(input_func) diff --git a/python/tvm/script/relax/parser.py b/python/tvm/script/relax/parser.py index 8522a6b9e79d..e2a1d9fe2d8e 100644 --- a/python/tvm/script/relax/parser.py +++ b/python/tvm/script/relax/parser.py @@ -20,18 +20,19 @@ from __future__ import annotations import inspect +import json from enum import Enum from typing import Union, Dict, List, Tuple, Optional, Callable, Any import tvm from tvm import relay, relax, tir +from tvm.relax.utils import metadata_partitioner import tvm.script from tvm.ir import diagnostics from tvm.ir.module import IRModule from tvm.script.tir.node import BufferSlice import tvm.script.tir as tir_namespace import tvm.script.relax as relax_namespace - import synr from synr import ast, Transformer @@ -70,6 +71,8 @@ class SpecialOp(Enum): DATAFLOW_OUTPUT = "relax.output" TUPLE = "relax.Tuple" TUPLE_GET_ITEM = "relax.TupleGetItem" + CONST = "relax.const" + CONSTANT = "relax.expr.Constant" class ArithmeticOp(Enum): @@ -102,6 +105,8 @@ class ArithmeticOp(Enum): class RelaxTransformer(Transformer): """A visitor to handle transformations on the Relax AST""" + meta_attr = None + def __init__(self, ir_mod: IRModule, relax_prefix: List[str], tir_prefix: List[str]): super().__init__() self.module = ir_mod @@ -160,6 +165,29 @@ def __exit__(self, *exc): return _Scope(self) + @classmethod + def update_meta(cls, metadata: str): + """Update the metadata attributes. + + Parameters + ---------- + metadata : str + The metadata to be parsed. + """ + + cls.meta_attr = metadata + + @classmethod + def get_meta(cls) -> str: + """Return the metadata attribute. + + Returns + ------- + str: + The metadata attributes + """ + return cls.meta_attr + @property def scope(self): """Returns the current definition scope. @@ -482,11 +510,16 @@ def transform_function(self, func: ast.Function, is_global: bool = False) -> rel self.report_error( "functions must be decorated as a Relax Function or TIR PrimFunc", func.span ) + decorator_name = None + if isinstance(func.decorators[0], ast.Call): + decorator_name = self._parse_attrs_to_str(func.decorators[0].func_name) + else: + decorator_name = self._parse_attrs_to_str(func.decorators[0]) - if self._parse_attrs_to_str(func.decorators[0]) == "tir.prim_func": + if decorator_name == "tir.prim_func": return self._tir_from_synr(func) - if self._parse_attrs_to_str(func.decorators[0]) != "relax.function": + if decorator_name != "relax.function": self.report_error( "functions must be decorated as a Relax Function or TIR PrimFunc", func.span ) @@ -612,7 +645,10 @@ def parse_var_binding(self, stmt: ast.Assign, is_dataflow=False) -> relax.VarBin The parsed Relax variable binding """ var = self._get_lhs(stmt) - rhs = self.transform_expr(stmt.rhs) + if isinstance(stmt.rhs, ast.Constant): + rhs = relax.const(stmt.rhs.value) + else: + rhs = self.transform_expr(stmt.rhs) # an ExternFunc call comes from call_packed bind_free_vars = isinstance(rhs, relay.Call) and isinstance(rhs.op, relax.ExternFunc) ty, shape = self.transform_type(stmt.ty, bind_free_vars) @@ -645,7 +681,6 @@ def transform_stmt( if isinstance(stmt, ast.Assign): # dataflow bindings are handled separately in parse_dataflow return self.parse_binding(stmt) - elif isinstance(stmt, ast.If): # check branches are non-empty if len(stmt.true.stmts) == 0 or len(stmt.false.stmts) == 0: @@ -853,6 +888,42 @@ def parse_attr(self, expr: ast.Attr) -> relax.Expr: # TODO(@altanh): maybe diagnostics here in case this fails? return relay.op.get(op_name) + def parse_array_literal( + self, expr: ast.ArrayLiteral + ) -> Union[relax.const, relax.expr.Constant]: + """Parses the given synr ArrayLiteral node to a Relax constant. + + Parameters + ---------- + expr : ast.ArrayLiteral + The synr ArrayLiteral to be parsed. + + Returns + ------- + Union[relax.const, relax.expr.Constant] + The parsed relex expression. + """ + + def _get_values(expr: ast.ArrayLiteral, vals: List[Any]) -> List[Any]: + # todo(@yongwww): the generic parsing util for ArrayLiteral should be in synr + if isinstance(expr, ast.Constant): + vals.append(expr.value) + elif isinstance(expr, ast.ArrayLiteral): + for elem in expr.values: + # recursive call to get the nested list + nested_vals = _get_values(elem, []) + # avoid nested list for every element + if len(nested_vals) == 1 and not isinstance(nested_vals[0], list): + vals.append(nested_vals[0]) + else: + vals.append(nested_vals) + else: + self.report_error(f"unsupported ast expression {expr.name}", expr.span) + return vals + + const_values = _get_values(expr, []) + return relax.const(const_values) + def parse_call(self, expr: ast.Call) -> Union[tir.PrimExpr, relax.Expr]: """Parses the given synr Call node to a Relax expression or PrimExpr. @@ -868,7 +939,39 @@ def parse_call(self, expr: ast.Call) -> Union[tir.PrimExpr, relax.Expr]: PrimExprs. """ if isinstance(expr.func_name, ast.Op) and expr.func_name.name == ast.BuiltinOp.Subscript: - return self.transform_Subscript(expr) + if ( + hasattr(expr.params[0], "params") + and hasattr(expr.params[0].params[0], "id") + and expr.params[0].params[0].id.name == "meta" + ): + # Get the index of constant in b64ndarrays in metadata + const_idx = 0 + if hasattr(expr.params[-1], "values"): + const_idx = expr.params[-1].values[0].value + + if self.module.get_attrs(): + metadata = self.module.get_attrs() + else: + metadata = RelaxTransformer.get_meta() + + if not metadata: + self.report_error( + f"metadata is not found, please feed it into ir_module", expr.span + ) + + attr_json = json.loads(str(metadata)) + new_root = const_num = 0 + for i, node in enumerate(attr_json["nodes"]): + if "type_key" in node and "Constant" in node["type_key"]: + if const_num == const_idx: + new_root = i + break + const_num += 1 + attr_json["root"] = new_root + return tvm.ir.load_json(json.dumps(attr_json)) + else: + return self.transform_Subscript(expr) + op = self.transform_expr(expr.func_name) if op == SpecialOp.CALL_PACKED: @@ -890,6 +993,15 @@ def parse_call(self, expr: ast.Call) -> Union[tir.PrimExpr, relax.Expr]: args = [self.transform_expr(arg) for arg in expr.params] # index of TupleGetItem only accepts int type intead of tir.expr.IntImm return relax.TupleGetItem(args[0], args[1].value) + elif op in (SpecialOp.CONSTANT, SpecialOp.CONST): + # relax const/Constant + arg = expr.params[0] + if isinstance(arg, ast.Constant): + return relax.const(arg.value) + elif isinstance(arg, ast.ArrayLiteral): + return self.parse_array_literal(arg) + else: + self.report_error(f"unsupported ast for const: {arg}", expr.span) elif isinstance(op, ArithmeticOp): args = [self.transform_expr(arg) for arg in expr.params] @@ -942,11 +1054,10 @@ def parse_call(self, expr: ast.Call) -> Union[tir.PrimExpr, relax.Expr]: attrs = None if kwargs or not is_default: attrs = tvm.ir.attrs.make_node(attrs_type_key, **kwargs) - return relay.Call(op, args, attrs=attrs, span=self.to_tvm_span(expr.span)) # Exprs: - # - ArrayLiteral: unsupported for now? + # - ArrayLiteral # - Attr: use for .shape, and intrinsic/special operator namespace # - Call # - Constant @@ -1010,10 +1121,10 @@ def transform_expr(self, expr: ast.Expr) -> relax.Expr: elif expr.value is None: return None else: - self.report_error( - f"unsupported constant expression: {expr}", - expr.span, - ) + return relax.const(expr.value) + + elif isinstance(expr, ast.ArrayLiteral): + return self.parse_array_literal(expr) elif isinstance(expr, ast.Op): # TODO(@altanh): might need to generalize from ArithmeticOp if we decide to support @@ -1022,7 +1133,6 @@ def transform_expr(self, expr: ast.Expr) -> relax.Expr: return ArithmeticOp(expr.name) except ValueError: self.report_error(f"unsupported built-in operator: {expr.name}", expr.span) - else: self.report_error(f"unsupported expression: {expr}", expr.span) @@ -1243,7 +1353,7 @@ def from_source( Parameters ---------- - input_module : Union[str, Callable] + input_func : Union[str, Callable] The python function to be parsed. relax_prefix : Optional[List[str]] @@ -1254,10 +1364,14 @@ def from_source( Returns ------- - output : Union[Function, Module] - The Function or Module in IR. + output : Union[Function, IRModule] + The relax Function or IRModule. """ - mod = IRModule() + metadata = None + if isinstance(input_func, str) and "b64ndarrays" in input_func: + input_func, metadata = metadata_partitioner(input_func) + + mod = IRModule(attrs=metadata) if isinstance(input_func, str): relax_prefix = ["R", "relax"] if relax_prefix is None else relax_prefix tir_prefix = ["T", "tir"] if tir_prefix is None else tir_prefix @@ -1301,19 +1415,23 @@ def from_source( # return mod -def pretty_print(node): +def pretty_print(node, show_meta_data=False): """Prints the given Relax IR node in the Relax text format. Parameters ---------- node : Union[relax.Type, relax.Expr, relax.Binding, relax.BindingBlock] The Relax IR node to print. + + show_meta_data : bool + Whether to include meta data section in the text + if there is meta data. """ - print(tvm.script._ffi_api.AsRelaxScript(node)) + print(tvm.script._ffi_api.AsRelaxScript(node, show_meta_data)) # TODO(@altanh): printer stuff should probably live elsewhere? -def astext(node) -> str: +def astext(node, show_meta_data=False) -> str: """Returns the Relax text format representation of the given Relax IR node. Parameters @@ -1321,9 +1439,15 @@ def astext(node) -> str: node : Union[relax.Type, relax.Expr, relax.Binding, relax.BindingBlock] The Relax IR node to print. + show_meta_data : bool + Whether to include meta data section in the text + if there is meta data. + Returns ------- - str + relax_text: str The text format representation of the given Relax IR node. + If show_meta_data is True, the meta data section will be printed in the beginning + of the the return string. """ - return tvm.script._ffi_api.AsRelaxScript(node) + return tvm.script._ffi_api.AsRelaxScript(node, show_meta_data) diff --git a/src/ir/module.cc b/src/ir/module.cc index 8d6de5a536a7..c82388cf516f 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -436,10 +436,8 @@ IRModule IRModule::FromText(const String& text, const String& source_path) { TVM_REGISTER_NODE_TYPE(IRModuleNode); TVM_REGISTER_GLOBAL("ir.IRModule") - .set_body_typed([](tvm::Map funcs, - tvm::Map types) { - return IRModule(funcs, types, {}); - }); + .set_body_typed([](tvm::Map funcs, tvm::Map types, + tvm::DictAttrs attrs) { return IRModule(funcs, types, {}, {}, attrs); }); TVM_REGISTER_GLOBAL("ir.Module_Add").set_body([](TVMArgs args, TVMRetValue* ret) { IRModule mod = args[0]; @@ -531,6 +529,10 @@ TVM_REGISTER_GLOBAL("ir.Module_GetAttr").set_body_typed([](IRModule mod, String return mod->GetAttr(key); }); +TVM_REGISTER_GLOBAL("ir.Module_GetAttrs").set_body_typed([](IRModule mod) -> ObjectRef { + return mod->GetAttrs(); +}); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); diff --git a/src/printer/relax_script_printer.cc b/src/printer/relax_script_printer.cc index b7460c9c0e5f..bfa97175da27 100644 --- a/src/printer/relax_script_printer.cc +++ b/src/printer/relax_script_printer.cc @@ -35,110 +35,6 @@ namespace tvm { namespace relax { -class RelaxScriptPrinter : public relax::IRFunctor, - public tir::ExprFunctor, - public TypeFunctor, - public AttrFunctor { - public: - TVM_DLL Doc Print(const ObjectRef& node); - - private: - NameTable name_table_; - int constant_counter_; - std::unordered_map var_id_map_; - std::unordered_map dim_var_map_; - - // IR nodes inherited from Relay - Doc VisitNode_(const relay::ConstantNode* op) override; - Doc VisitNode_(const relay::TupleNode* op) override; - Doc VisitNode_(const relay::GlobalVarNode* op) override; - Doc VisitNode_(const relay::CallNode* op) override; - // Doc VisitNode_(const relay::IfNode* op) override; - Doc VisitNode_(const OpNode* op) override; - Doc VisitNode_(const relay::TupleGetItemNode* op) override; - - // IR nodes introduced by Relax - Doc VisitNode_(const relax::VarNode* op) override; - Doc VisitNode_(const relax::DataflowVarNode* op) override; - Doc VisitNode_(const relax::ShapeExprNode* op) override; - Doc VisitNode_(const relax::MatchShapeNode* op) override; - Doc VisitNode_(const relax::VarBindingNode* op) override; - Doc VisitNode_(const relax::BindingBlockNode* op) override; - Doc VisitNode_(const relax::DataflowBlockNode* op) override; - Doc VisitNode_(const relax::SeqExprNode* op) override; - Doc VisitNode_(const relax::FunctionNode* op) override; - Doc VisitNode_(const relax::ExternFuncNode* op) override; - - // PrimExpr nodes allowed in Relax - Doc VisitExpr_(const tir::VarNode* op) override; - Doc VisitExpr_(const tir::IntImmNode* op) override; - Doc VisitExpr_(const tir::AddNode* op) override; - Doc VisitExpr_(const tir::SubNode* op) override; - Doc VisitExpr_(const tir::MulNode* op) override; - Doc VisitExpr_(const tir::DivNode* op) override; - Doc VisitExpr_(const tir::FloorDivNode* op) override; - - Doc PrintIRModule(const IRModule& mod); - Doc PrintPrimFunc(const String& name, const tir::PrimFunc& func); - - Doc PrintIfStmt(const relax::Var& var, const relay::If& ite); - Doc PrintFunctionDef(const Doc& name, const relax::Function& func); - - Doc PrintVarAnnotation(const relax::Var& var); - Doc PrintTensorAnnotation(const relax::DynTensorType& ty, const Optional& shape); - - Doc VisitType_(const relax::ShapeTypeNode* node) override; - Doc VisitType_(const relax::DynTensorTypeNode* node) override; - Doc VisitType_(const relay::TupleTypeNode* node) override; - - Doc PrintAttr(const ObjectRef& attr); - std::vector PrintAttrs(const Attrs& attrs); - Doc VisitAttrDefault_(const Object* op) override; - Doc VisitAttr_(const ArrayNode* op) override; - Doc VisitAttr_(const tir::IntImmNode* op) override; - Doc VisitAttr_(const tir::FloatImmNode* op) override; - - Doc GetUniqueName(std::string prefix, std::string fallback); - - /*! - * \brief Attribute printer which prints the attributes as kwargs in a call. - */ - class AttrPrinter : public AttrVisitor { - public: - AttrPrinter(std::vector* docs, RelaxScriptPrinter* parent) : docs(docs), parent_(parent) {} - - template - void PrintKV(const char* key, const T& value) { - Doc doc; - doc << key << "=" << value; - docs->push_back(doc); - } - - void Visit(const char* key, double* value) final { PrintKV(key, *value); } - void Visit(const char* key, int64_t* value) final { PrintKV(key, *value); } - void Visit(const char* key, uint64_t* value) final { PrintKV(key, *value); } - void Visit(const char* key, int* value) final { PrintKV(key, *value); } - void Visit(const char* key, bool* value) final { PrintKV(key, Doc::PyBoolLiteral(*value)); } - void Visit(const char* key, std::string* value) final { PrintKV(key, Doc::StrLiteral(*value)); } - void Visit(const char* key, void** value) final { - LOG(FATAL) << "do not allow void as argument"; - } - void Visit(const char* key, DataType* value) final { - PrintKV(key, Doc::StrLiteral(runtime::DLDataType2String(*value))); - } - void Visit(const char* key, runtime::NDArray* value) final { - LOG(FATAL) << "do not allow NDarray as argument"; - } - void Visit(const char* key, runtime::ObjectRef* obj) final { - PrintKV(key, parent_->PrintAttr(*obj)); - } - - private: - std::vector* docs; - RelaxScriptPrinter* parent_; - }; -}; - Doc RelaxScriptPrinter::Print(const ObjectRef& node) { if (node->IsInstance()) { return PrintIRModule(Downcast(node)); @@ -216,12 +112,6 @@ Doc RelaxScriptPrinter::VisitNode_(const relay::TupleGetItemNode* op) { return doc; } -Doc RelaxScriptPrinter::VisitNode_(const relay::ConstantNode* op) { - Doc doc; - doc << "meta[relax.Constant][" << constant_counter_++ << "]"; - return doc; -} - Doc RelaxScriptPrinter::VisitNode_(const relax::VarNode* op) { if (!var_id_map_.count(op->vid)) { var_id_map_[op->vid] = GetUniqueName(op->name_hint(), "v"); @@ -230,6 +120,61 @@ Doc RelaxScriptPrinter::VisitNode_(const relax::VarNode* op) { return var_id_map_[op->vid]; } +/*! + * \brief special method to print out const scalar + * \param dtype The data type + * \param value The value to be printed. + */ +template +Doc ScalarLiteral(DataType dtype, const T& value) { + std::ostringstream os; + if (dtype == DataType::Bool()) { + return Doc::PyBoolLiteral(value != 0); + } else { + os << value; + } + return Doc::Text(os.str()); +} + +// Overload of Expr printing functions +Doc RelaxScriptPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline, + bool optional_info) { + Doc printed_expr; + if (meta) { + printed_expr = meta_->GetMetaNode(GetRef(expr.get())); + } else { + printed_expr = VisitNode(expr); + } + return printed_expr; +} + +Doc RelaxScriptPrinter::VisitNode_(const relax::ConstantNode* op) { + Doc doc; + // Print out simple scalars directly. + if (op->data->ndim == 0) { + std::ostringstream os; + DataType dtype = DataType(op->data->dtype); + ICHECK_EQ(op->data->device.device_type, kDLCPU); + auto scalar_val = ScalarLiteral(dtype, 0); + if (dtype == DataType::Int(32)) { + scalar_val = ScalarLiteral(dtype, static_cast(op->data->data)[0]); + } else if (dtype == DataType::Int(64)) { + scalar_val = ScalarLiteral(dtype, static_cast(op->data->data)[0]); + } else if (dtype == DataType::Float(32)) { + scalar_val = ScalarLiteral(dtype, static_cast(op->data->data)[0]); + } else if (dtype == DataType::Float(64)) { + scalar_val = ScalarLiteral(dtype, static_cast(op->data->data)[0]); + } else if (dtype == DataType::Bool()) { + scalar_val = ScalarLiteral(dtype, static_cast(op->data->data)[0]); + } + return doc << scalar_val; + } + // default fall-back, record it as meta node. + // Don't append optional_info. Because the entry function is Print, + // and it will append the optional_info afterwards. + return doc << PrintExpr(GetRef(op), true, false, false); +} + Doc RelaxScriptPrinter::VisitNode_(const relax::DataflowVarNode* op) { if (!var_id_map_.count(op->vid)) { var_id_map_[op->vid] = GetUniqueName(op->name_hint(), "dv"); @@ -435,7 +380,11 @@ Doc RelaxScriptPrinter::VisitAttr_(const tir::FloatImmNode* op) { Doc RelaxScriptPrinter::PrintIRModule(const IRModule& mod) { Doc doc; - doc << "@tvm.script.ir_module" << Doc::NewLine(); + if (ShowMetaData()) { + doc << "@tvm.script.ir_module(metadata=metadata)" << Doc::NewLine(); + } else { + doc << "@tvm.script.ir_module" << Doc::NewLine(); + } doc << "class Module:"; for (const std::pair& pr : mod->functions) { Doc func; @@ -485,8 +434,11 @@ Doc RelaxScriptPrinter::PrintFunctionDef(const Doc& name, const relax::Function& param << Print(var) << PrintVarAnnotation(var); params.push_back(param); } - - doc << "@relax.function" << Doc::NewLine(); + if (ShowMetaData()) { + doc << "@relax.function(metadata=metadata)" << Doc::NewLine(); + } else { + doc << "@relax.function" << Doc::NewLine(); + } doc << "def " << name << "(" << Doc::Concat(params, Doc::Text(", ")) << ")"; if (func->ret_type.defined()) { doc << " -> " << Print(func->ret_type); @@ -555,9 +507,14 @@ Doc RelaxScriptPrinter::GetUniqueName(std::string prefix, std::string fallback = return Doc::Text(name_table_.GetUniqueName(prefix)); } -String AsRelaxScript(const ObjectRef& mod) { +bool RelaxScriptPrinter::ShowMetaData() { return show_meta_data_; } + +String AsRelaxScript(const ObjectRef& mod, bool show_meta_data) { ICHECK(mod->IsInstance() || mod->IsInstance()); - return RelaxScriptPrinter().Print(mod).str(); + Doc doc; + runtime::TypedPackedFunc ftyped = nullptr; + doc << TextPrinter(show_meta_data, ftyped).PrintRelax(mod); + return doc.str(); } TVM_REGISTER_GLOBAL("script.AsRelaxScript").set_body_typed(AsRelaxScript); diff --git a/src/printer/text_printer.cc b/src/printer/text_printer.cc index 4d4113fef694..29de9a33a3d9 100644 --- a/src/printer/text_printer.cc +++ b/src/printer/text_printer.cc @@ -81,6 +81,8 @@ Doc TextPrinter::PrintMod(const IRModule& mod) { } else if (base_func.as()) { doc << "@" << var->name_hint; doc << " = " << tir_text_printer_.PrintPrimFunc(Downcast(base_func)); + } else if (base_func.as()) { + doc << relax_text_printer_.Print(base_func); } doc << Doc::NewLine(); } diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index c3746433e37c..54e3c036fefc 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -28,6 +28,8 @@ #include #include +#include +#include #include #include #include @@ -233,6 +235,121 @@ class RelayTextPrinter : public ExprFunctor, } // namespace relay } // namespace tvm +namespace tvm { +namespace relax { +class RelaxScriptPrinter : public relax::IRFunctor, + public tir::ExprFunctor, + public TypeFunctor, + public AttrFunctor { + public: + explicit RelaxScriptPrinter(bool show_meta_data, TextMetaDataContext* meta) + : show_meta_data_(show_meta_data), meta_(meta) {} + TVM_DLL Doc Print(const ObjectRef& node); + bool ShowMetaData(); + + private: + NameTable name_table_; + /*! \brief Whether to print meta data. */ + bool show_meta_data_; + /*! \brief meta data context */ + TextMetaDataContext* meta_; + std::unordered_map var_id_map_; + std::unordered_map dim_var_map_; + + // IR nodes inherited from Relay + Doc VisitNode_(const relay::TupleNode* op) override; + Doc VisitNode_(const relay::GlobalVarNode* op) override; + Doc VisitNode_(const relay::ConstantNode* op) override; + Doc VisitNode_(const relay::CallNode* op) override; + // Doc VisitNode_(const relay::IfNode* op) override; + Doc VisitNode_(const OpNode* op) override; + Doc VisitNode_(const relay::TupleGetItemNode* op) override; + + // IR nodes introduced by Relax + Doc VisitNode_(const relax::VarNode* op) override; + Doc VisitNode_(const relax::DataflowVarNode* op) override; + Doc VisitNode_(const relax::ShapeExprNode* op) override; + Doc VisitNode_(const relax::MatchShapeNode* op) override; + Doc VisitNode_(const relax::VarBindingNode* op) override; + Doc VisitNode_(const relax::BindingBlockNode* op) override; + Doc VisitNode_(const relax::DataflowBlockNode* op) override; + Doc VisitNode_(const relax::SeqExprNode* op) override; + Doc VisitNode_(const relax::FunctionNode* op) override; + Doc VisitNode_(const relax::ExternFuncNode* op) override; + + // PrimExpr nodes allowed in Relax + Doc VisitExpr_(const tir::VarNode* op) override; + Doc VisitExpr_(const tir::IntImmNode* op) override; + Doc VisitExpr_(const tir::AddNode* op) override; + Doc VisitExpr_(const tir::SubNode* op) override; + Doc VisitExpr_(const tir::MulNode* op) override; + Doc VisitExpr_(const tir::DivNode* op) override; + Doc VisitExpr_(const tir::FloorDivNode* op) override; + + Doc PrintIRModule(const IRModule& mod); + Doc PrintPrimFunc(const String& name, const tir::PrimFunc& func); + + Doc PrintIfStmt(const relax::Var& var, const relay::If& ite); + Doc PrintFunctionDef(const Doc& name, const relax::Function& func); + + Doc PrintVarAnnotation(const relax::Var& var); + Doc PrintTensorAnnotation(const relax::DynTensorType& ty, const Optional& shape); + + Doc VisitType_(const relax::ShapeTypeNode* node) override; + Doc VisitType_(const relax::DynTensorTypeNode* node) override; + Doc VisitType_(const relay::TupleTypeNode* node) override; + + Doc PrintAttr(const ObjectRef& attr); + std::vector PrintAttrs(const Attrs& attrs); + Doc VisitAttrDefault_(const Object* op) override; + Doc PrintExpr(const Expr& expr, bool meta, bool try_inline, bool optional_info = true); + Doc VisitAttr_(const ArrayNode* op) override; + Doc VisitAttr_(const tir::IntImmNode* op) override; + Doc VisitAttr_(const tir::FloatImmNode* op) override; + + Doc GetUniqueName(std::string prefix, std::string fallback); + + /*! + * \brief Attribute printer which prints the attributes as kwargs in a call. + */ + class AttrPrinter : public AttrVisitor { + public: + AttrPrinter(std::vector* docs, RelaxScriptPrinter* parent) : docs(docs), parent_(parent) {} + + template + void PrintKV(const char* key, const T& value) { + Doc doc; + doc << key << "=" << value; + docs->push_back(doc); + } + + void Visit(const char* key, double* value) final { PrintKV(key, *value); } + void Visit(const char* key, int64_t* value) final { PrintKV(key, *value); } + void Visit(const char* key, uint64_t* value) final { PrintKV(key, *value); } + void Visit(const char* key, int* value) final { PrintKV(key, *value); } + void Visit(const char* key, bool* value) final { PrintKV(key, Doc::PyBoolLiteral(*value)); } + void Visit(const char* key, std::string* value) final { PrintKV(key, Doc::StrLiteral(*value)); } + void Visit(const char* key, void** value) final { + LOG(FATAL) << "do not allow void as argument"; + } + void Visit(const char* key, DataType* value) final { + PrintKV(key, Doc::StrLiteral(runtime::DLDataType2String(*value))); + } + void Visit(const char* key, runtime::NDArray* value) final { + LOG(FATAL) << "do not allow NDarray as argument"; + } + void Visit(const char* key, runtime::ObjectRef* obj) final { + PrintKV(key, parent_->PrintAttr(*obj)); + } + + private: + std::vector* docs; + RelaxScriptPrinter* parent_; + }; +}; +} // namespace relax +} // namespace tvm + namespace tvm { namespace tir { @@ -426,6 +543,7 @@ class TextPrinter { show_warning_(show_warning), annotate_(annotate), relay_text_printer_(show_meta_data, &meta_, annotate), + relax_text_printer_(show_meta_data, &meta_), tir_text_printer_(show_meta_data, &meta_) {} /*! \brief whether show meta data */ @@ -440,6 +558,8 @@ class TextPrinter { runtime::TypedPackedFunc annotate_; /*! \brief Relay Text Printer */ relay::RelayTextPrinter relay_text_printer_; + /*! \brief Relax Text Printer */ + relax::RelaxScriptPrinter relax_text_printer_; /*! \brief TIR Text Printer */ tir::TIRTextPrinter tir_text_printer_; @@ -453,6 +573,8 @@ class TextPrinter { (node->IsInstance() || node->IsInstance() || node->IsInstance())) { doc << tir_text_printer_.Print(node); + } else if (node.defined() && node->IsInstance()) { + doc << relax_text_printer_.Print(node); } else { doc << relay_text_printer_.PrintFinal(node); } @@ -470,6 +592,18 @@ class TextPrinter { return doc; } + Doc PrintRelax(const ObjectRef& node) { + relax_text_printer_.Print(node); + Doc doc; + if (show_meta_data_ && !meta_.empty()) { + doc << "metadata = "; + doc << meta_.GetMetaSection(); + doc << Doc::NewLine(); + } + doc << relax_text_printer_.Print(node); + return doc; + } + Doc PrintMod(const IRModule& mod); }; } // namespace tvm diff --git a/tests/python/relax/test_parser.py b/tests/python/relax/test_parser.py index fed51acf414f..27c844f2dccc 100644 --- a/tests/python/relax/test_parser.py +++ b/tests/python/relax/test_parser.py @@ -18,7 +18,6 @@ from __future__ import annotations # must import to defer parsing of annotations import pytest import tvm - from tvm import tir, relay, relax from tvm.ir import assert_structural_equal @@ -505,6 +504,28 @@ def f(x: Tensor[(3, 3), "float32"]): assert isinstance(w_bind.value.attrs, relay.op.op_attrs.ShapeOfAttrs) +def test_constant(): + @R.function + def f(x: Tensor[(2, 3), "float32"]): + y1 = relax.const(2, dtype="float32") + y2 = relax.const([[3.1, 4.0, 5.0], [6.0, 7.1, 9.0]]) + z = add(x, y1) + r = add(z, y2) + return r + + x = f.params[0] + bind_0 = f.body.blocks[0].bindings[0] + assert bind_0.var.name_hint == "y1" + bind_1 = f.body.blocks[0].bindings[1] + assert bind_1.var.name_hint == "y2" + bind_2 = f.body.blocks[0].bindings[2] + assert bind_2.var.name_hint == "z" + bind_3 = f.body.blocks[0].bindings[3] + assert bind_3.var.name_hint == "r" + check_call(bind_2.value, "add", [x, bind_0.var]) + check_call(bind_3.value, "add", [bind_2.var, bind_1.var]) + + def test_primexpr_arithmetic(): @R.function def f(x: Tensor[(n, m), "float32"]): diff --git a/tests/python/relax/test_printer.py b/tests/python/relax/test_printer.py index f42302d682e8..5377d00d57eb 100644 --- a/tests/python/relax/test_printer.py +++ b/tests/python/relax/test_printer.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations # must import to defer parsing of annotations + import pytest import tvm @@ -24,14 +25,16 @@ from tvm.ir import structural_equal, assert_structural_equal import tvm.script +from tvm.relax.utils import metadata_partitioner from tvm.script import tir as T, relax as R def check_roundtrip(f_pre): - print(R.parser.astext(f_pre)) - f_post = R.parser.from_source(R.parser.astext(f_pre)) + relax_text = R.parser.astext(f_pre, show_meta_data=True) + f_post = R.parser.from_source(input_func=relax_text) if isinstance(f_pre, tvm.IRModule) and not isinstance(f_post, tvm.IRModule): - f_post = f_post() + global_vars = f_pre.get_global_vars() + f_post = tvm.IRModule({global_vars[0]: f_post}, attrs=metadata) assert_structural_equal(f_pre, f_post, map_free_vars=True) @@ -194,6 +197,80 @@ def foo(x: Tensor): check_roundtrip(foo) +def test_const_irmodule(): + def _gen_meta_data(): + @tvm.script.ir_module + class Module: + @R.function + def my_const(x: Tensor[(2, 3), "float32"]): + y: Tensor[(2, 3), "float32"] = relax.const( + [[0.1, 1.1, 2.1], [3.1, 4.1, 5.1]], dtype="float32" + ) + z: Tensor[(2, 3), "float32"] = relax.add(x, y) + return z + + mod = Module + relax_text = R.parser.astext(mod, show_meta_data=True) + texts = metadata_partitioner(relax_text) + return texts[1] + + json_str = _gen_meta_data() + + @tvm.script.ir_module(metadata=json_str) + class MyModule: + @R.function + def my_const(x: Tensor[(2, 3), "float32"]): + z: Tensor[(2, 3), "float32"] = relax.add(x, meta[relay.Constant][0]) + return z + + my_module = MyModule + + check_roundtrip(my_module) + + +def test_const(): + @R.function + def my_const(x: Tensor[(2, 3), "float32"]): + y1 = relax.const([[0.1, 1.1, 2.1], [3.1, 4.1, 5.1]]) + y2 = relax.const(2.1, dtype="float32") + y3 = relax.const([[3.0, 3.0, 3.0], [3.0, 3.0, 3.0]]) + z = relax.add(x, y1) + r = relax.add(z, y2) + w = relax.add(r, y3) + return w + + check_roundtrip(my_const) + + +def test_const_meta(): + def _get_meta_data(): + @R.function + def my_const(x: Tensor[(2, 3), "float32"]): + y1: Tensor[(2, 3), "float32"] = relax.const([[0.1, 1.1, 2.1], [3.1, 4.1, 5.1]]) + y2 = relax.const(2.1, dtype="float32") + y3: Tensor[(2, 3), "float32"] = relax.const([[3.0, 3.0, 3.0], [3.0, 3.0, 3.0]]) + z: Tensor[(2, 3), "float32"] = relax.add(x, y1) + r: Tensor[(2, 3), "float32"] = relax.add(z, y2) + w: Tensor[(2, 3), "float32"] = relax.add(r, y3) + return w + + relax_text = R.parser.astext(my_const, show_meta_data=True) + texts = metadata_partitioner(relax_text) + return texts[1] + + json_str = _get_meta_data() + + @R.function(metadata=json_str) + def my_const(x: Tensor[(2, 3), "float32"]): + y2 = relax.const(2.1, dtype="float32") + z: Tensor[(2, 3), "float32"] = relax.add(x, meta[relay.Constant][0]) + r: Tensor[(2, 3), "float32"] = relax.add(z, y2) + w: Tensor[(2, 3), "float32"] = relax.add(r, meta[relay.Constant][1]) + return w + + check_roundtrip(my_const) + + def test_class_irmodule(): @tvm.script.ir_module class MyModule: