diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 61b3e13c1630a..eceafec75fa12 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -178,7 +178,7 @@ class RelayExpr : public BaseExpr { class GlobalVar; /*! - * \brief Global variable that leaves in the top-level module. + * \brief Global variable that lives in the top-level module. * * A GlobalVar only refers to function definitions. * This is used to enable recursive calls between function. diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index 56f2389ad3851..9e87731dae723 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -141,11 +141,12 @@ enum TypeKind : int { }; /*! - * \brief Type parameter in the function. - * This can be viewed as template parameter in c++ template function. + * \brief Type parameter in functions. + * + * A type variable can be viewed as template parameter in c++ template function. * * For example, in the following pesudo code, - * the TypeVar of f is TypeVar(kind=kShapeVar, var=n). + * the TypeVar of f is TypeVar("n", kind=kShapeVar). * This function can take in a Tensor with shape=(3, 3) and * returns a Tensor with shape=(9,) * diff --git a/include/tvm/ir/type_relation.h b/include/tvm/ir/type_relation.h index 6d4e75a23f6b8..ff36b9671fa8f 100644 --- a/include/tvm/ir/type_relation.h +++ b/include/tvm/ir/type_relation.h @@ -165,7 +165,7 @@ using TypeRelationFn = const TypeReporter& reporter)>; /*! - * \brief User defined type relation, is an input-output relation on types. + * \brief User defined type relation, it is an input-output relation on types. * * TypeRelation is more generalized than type call as it allows inference * of both inputs and outputs. diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py new file mode 100644 index 0000000000000..3a92b4b8c1293 --- /dev/null +++ b/python/tvm/ir/__init__.py @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import +"""Common data structures across all IR variants.""" +from .base import SourceName, Span, Node +from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, BaseFunc +from .type import Type, TypeKind, TypeVar, GlobalTypeVar, TupleType +from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType +from .type_relation import TypeCall, TypeRelation +from .tensor_type import TensorType diff --git a/python/tvm/ir/_ffi_api.py b/python/tvm/ir/_ffi_api.py new file mode 100644 index 0000000000000..74f660c509014 --- /dev/null +++ b/python/tvm/ir/_ffi_api.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.ir""" +import tvm._ffi + +# Exports functions registered via TVM_REGISTER_GLOBAL with the "runtime" prefix. +# e.g. TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile") +tvm._ffi._init_api("ir", __name__) diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py new file mode 100644 index 0000000000000..a9d310b807cfd --- /dev/null +++ b/python/tvm/ir/base.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Span data structures.""" +import tvm._ffi + +from tvm.runtime import Object +from . import _ffi_api + + +class Node(Object): + """Base class of all IR Nodes, implements astext function.""" + def astext(self, show_meta_data=True, annotate=None): + """Get the text format of the expression. + + Parameters + ---------- + show_meta_data : bool + Whether to include meta data section in the text + if there is meta data. + + annotate: Optional[Object->str] + Optional annotate function to provide additional + information in the comment block. + + Note + ---- + The meta data section is necessary to fully parse the text format. + However, it can contain dumps that are big (e.g constant weights), + so it can be helpful to skip printing the meta data section. + + Returns + ------- + text : str + The text format of the expression. + """ + return _ffi_api.AsText(self, show_meta_data, annotate) + + def __str__(self): + return self.astext(show_meta_data=False) + + +@tvm._ffi.register_object("relay.SourceName") +class SourceName(Object): + """A identifier for a source location. + + Parameters + ---------- + name : str + The name of the source. + """ + def __init__(self, name): + self.__init_handle_by_constructor__(_ffi_api.SourceName, name) + + +@tvm._ffi.register_object("relay.Span") +class Span(Object): + """Specifies a location in a source program. + + Parameters + ---------- + source : SourceName + The source name. + + lineno : int + The line number. + + col_offset : int + The column offset of the location. + """ + def __init__(self, source, lineno, col_offset): + self.__init_handle_by_constructor__( + _ffi_api.Span, source, lineno, col_offset) diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py new file mode 100644 index 0000000000000..6b976f24bb1b6 --- /dev/null +++ b/python/tvm/ir/expr.py @@ -0,0 +1,91 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Common expressions data structures in the IR.""" +import tvm._ffi + +from .base import Node +from . import _ffi_api + +class BaseExpr(Node): + """Base class of all the expressions.""" + + +class PrimExpr(BaseExpr): + """Base class of all primitive expressions. + + PrimExpr is used in the low-level code + optimizations and integer analysis. + """ + + +class RelayExpr(BaseExpr): + """Base class of all non-primitive expressions.""" + @property + def checked_type(self): + """Get the checked type of tvm.relay.Expr. + + Returns + ------- + checked_type : tvm.relay.Type + The checked type. + """ + ret = self._checked_type_ + if ret is None: + raise ValueError("The type checker has not populated" + " the checked_type for this node") + return ret + + +class BaseFunc(RelayExpr): + """Base class of all functions.""" + + +@tvm._ffi.register_object("relay.GlobalVar") +class GlobalVar(RelayExpr): + """A global variable in the IR. + + GlobalVar is used to refer to the global functions + stored in the IRModule. + + Parameters + ---------- + name_hint: str + The name of the variable. + """ + def __init__(self, name_hint): + self.__init_handle_by_constructor__(_ffi_api.GlobalVar, name_hint) + + def __call__(self, *args): + """Call the global variable. + + Parameters + ---------- + args: List[RelayExpr] + The arguments to the call. + + Returns + ------- + call: Call + A call taking the variable as a function. + """ + # pylint: disable=import-outside-toplevel + if all(isinstance(x, RelayExpr) for x in args): + from tvm import relay + return relay.Call(self, args) + arg_types = [type(x) for x in args] + raise RuntimeError( + "Do not know how to handle GlobalVar.__call__ for types {}".format(arg_types)) diff --git a/python/tvm/ir/tensor_type.py b/python/tvm/ir/tensor_type.py new file mode 100644 index 0000000000000..99286ed13fd23 --- /dev/null +++ b/python/tvm/ir/tensor_type.py @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Type relation and function for type checking.""" +import tvm._ffi + +from .type import Type +from . import _ffi_api + + +@tvm._ffi.register_object("relay.TensorType") +class TensorType(Type): + """A concrete TensorType in Relay. + + This is the type assigned to tensors with a known dtype and shape. + For example, a tensor of `float32` and `(5, 5)`. + + Parameters + ---------- + shape : List[tvm.ir.PrimExpr] + The shape of the Tensor + + dtype : Optional[str] + The content data type. + """ + def __init__(self, shape, dtype="float32"): + self.__init_handle_by_constructor__( + _ffi_api.TensorType, shape, dtype) + + @property + def concrete_shape(self): + """Get shape of the type as concrete tuple of int. + + Returns + ------- + shape : List[int] + The concrete shape of the Type. + + Raises + ------ + TypeError : If the shape is symbolic + """ + return tuple(int(x) for x in self.shape) diff --git a/python/tvm/ir/type.py b/python/tvm/ir/type.py new file mode 100644 index 0000000000000..21d1e793482a6 --- /dev/null +++ b/python/tvm/ir/type.py @@ -0,0 +1,205 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unified type system in the project.""" +from enum import IntEnum +import tvm._ffi + +from .base import Node +from . import _ffi_api + + +class Type(Node): + """The base class of all types.""" + def __eq__(self, other): + """Compare two types for structural equivalence.""" + return bool(_ffi_api.type_alpha_equal(self, other)) + + def __ne__(self, other): + return not self.__eq__(other) + + def same_as(self, other): + """Compares two Relay types by referential equality.""" + return super().__eq__(other) + + +class TypeKind(IntEnum): + """Possible kinds of TypeVars.""" + Type = 0 + ShapeVar = 1 + BaseType = 2 + Constraint = 4 + AdtHandle = 5 + TypeData = 6 + + +@tvm._ffi.register_object("relay.TypeVar") +class TypeVar(Type): + """Type parameter in functions. + + A type variable represents a type placeholder which will + be filled in later on. This allows the user to write + functions which are generic over types. + + Parameters + ---------- + name_hint: str + The name of the type variable. This name only acts as a hint, and + is not used for equality. + + kind : Optional[TypeKind] + The kind of the type parameter. + """ + def __init__(self, name_hint, kind=TypeKind.Type): + self.__init_handle_by_constructor__( + _ffi_api.TypeVar, name_hint, kind) + + def __call__(self, *args): + """Create a type call from this type. + + Parameters + ---------- + args: List[Type] + The arguments to the type call. + + Returns + ------- + call: Type + The result type call. + """ + # pylint: disable=import-outside-toplevel + from .type_relation import TypeCall + return TypeCall(self, args) + + +@tvm._ffi.register_object("relay.GlobalTypeVar") +class GlobalTypeVar(Type): + """A global type variable that is used for defining new types or type aliases. + + Parameters + ---------- + name_hint: str + The name of the type variable. This name only acts as a hint, and + is not used for equality. + + kind : Optional[TypeKind] + The kind of the type parameter. + """ + def __init__(self, name_hint, kind=TypeKind.AdtHandle): + self.__init_handle_by_constructor__( + _ffi_api.GlobalTypeVar, name_hint, kind) + + def __call__(self, *args): + """Create a type call from this type. + + Parameters + ---------- + args: List[Type] + The arguments to the type call. + + Returns + ------- + call: Type + The result type call. + """ + # pylint: disable=import-outside-toplevel + from .type_relation import TypeCall + return TypeCall(self, args) + + +@tvm._ffi.register_object("relay.TupleType") +class TupleType(Type): + """The type of tuple values. + + Parameters + ---------- + fields : List[Type] + The fields in the tuple + """ + + def __init__(self, fields): + self.__init_handle_by_constructor__( + _ffi_api.TupleType, fields) + + +@tvm._ffi.register_object("relay.TypeConstraint") +class TypeConstraint(Type): + """Abstract class representing a type constraint.""" + + +@tvm._ffi.register_object("relay.FuncType") +class FuncType(Type): + """Function type. + + + A function type consists of a list of type parameters to enable + the definition of generic functions, + a set of type constraints which we omit for the time being, + a sequence of argument types, and a return type. + + We can informally write them as: + `forall (type_params), (arg_types) -> ret_type where type_constraints` + + Parameters + ---------- + arg_types : List[tvm.relay.Type] + The argument types + + ret_type : tvm.relay.Type + The return type. + + type_params : Optional[List[tvm.relay.TypeVar]] + The type parameters + + type_constraints : Optional[List[tvm.relay.TypeConstraint]] + The type constraints. + """ + def __init__(self, + arg_types, + ret_type, + type_params=None, + type_constraints=None): + if type_params is None: + type_params = [] + if type_constraints is None: + type_constraints = [] + self.__init_handle_by_constructor__( + _ffi_api.FuncType, arg_types, ret_type, type_params, type_constraints) + + +@tvm._ffi.register_object("relay.IncompleteType") +class IncompleteType(Type): + """Incomplete type during type inference. + + kind : Optional[TypeKind] + The kind of the incomplete type. + """ + def __init__(self, kind=TypeKind.Type): + self.__init_handle_by_constructor__( + _ffi_api.IncompleteType, kind) + + +@tvm._ffi.register_object("relay.RefType") +class RelayRefType(Type): + """Reference Type in relay. + + Parameters + ---------- + value: Type + The value type. + """ + def __init__(self, value): + self.__init_handle_by_constructor__(_ffi_api.RelayRefType, value) diff --git a/python/tvm/ir/type_relation.py b/python/tvm/ir/type_relation.py new file mode 100644 index 0000000000000..63c83d9af0423 --- /dev/null +++ b/python/tvm/ir/type_relation.py @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Type relation and function for type checking.""" +import tvm._ffi + +from .type import Type, TypeConstraint +from . import _ffi_api + + +class TypeCall(Type): + """Type function application. + + Parameters + ---------- + func: tvm.ir.Type + The function. + + args: List[tvm.ir.Type] + The arguments. + + Returns + ------- + type_call: TypeCall + The type function application. + """ + def __init__(self, func, args): + self.__init_handle_by_constructor__(_ffi_api.TypeCall, func, args) + + +@tvm._ffi.register_object("relay.TypeRelation") +class TypeRelation(TypeConstraint): + """User defined type relation, it is an input-output relation on types. + + TypeRelation is more generalized than TypeCall as it allows inference + of both inputs and outputs. + + Parameters + ---------- + func : EnvFunc + User defined relation function. + + args : [tvm.ir.Type] + List of types to the func. + + num_inputs : int + Number of input arguments in args, + this act as a hint for type inference. + + attrs : Attrs + The attribute attached to the relation information + + Returns + ------- + type_relation : tvm.ir.TypeRelation + The type relation. + """ + def __init__(self, func, args, num_inputs, attrs): + self.__init_handle_by_constructor__( + _ffi_api.TypeRelation, func, args, num_inputs, attrs) diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 2432ec31cfe57..ac47673ddb7e4 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -74,7 +74,7 @@ Type = ty.Type TupleType = ty.TupleType TensorType = ty.TensorType -Kind = ty.Kind +Kind = ty.TypeKind TypeVar = ty.TypeVar ShapeVar = ty.ShapeVar TypeConstraint = ty.TypeConstraint @@ -88,7 +88,7 @@ Any = ty.Any # Expr -Expr = expr.Expr +Expr = expr.RelayExpr Constant = expr.Constant Tuple = expr.Tuple Var = expr.Var diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py index 0fd1c105a3d1c..3ae89af01f6e9 100644 --- a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/_parser.py @@ -37,6 +37,7 @@ def __new__(cls, *args, **kwds): return deque.__new__(cls, *args, **kwds) import tvm +import tvm.ir._ffi_api from . import module from .base import Span, SourceName @@ -190,7 +191,7 @@ def _wrapper(*args, **kwargs): sp = Span(sn, line, col) if isinstance(ast, tvm.relay.expr.TupleWrapper): ast = ast.astuple() - ast.set_span(sp) + tvm.ir._ffi_api.NodeSetSpan(ast, sp) return ast return _wrapper @@ -243,7 +244,7 @@ def exit_type_param_scope(self) -> Scope[ty.TypeVar]: """Pop off the current TypeVar scope and return it.""" return self.type_var_scopes.popleft() - def mk_typ(self, name: str, kind: ty.Kind) -> ty.TypeVar: + def mk_typ(self, name: str, kind: ty.TypeKind) -> ty.TypeVar: """Create a new TypeVar and add it to the TypeVar scope.""" typ = ty.TypeVar(name, kind) self.type_var_scopes[0].append((name, typ)) @@ -274,7 +275,7 @@ def _type_expr_name(self, e): if isinstance(e, adt.Constructor): return "`{0}` ADT constructor".format(e.belong_to.name_hint) if isinstance(e, ty.GlobalTypeVar): - if e.kind == ty.Kind.AdtHandle: + if e.kind == ty.TypeKind.AdtHandle: return "ADT definition" return "function definition" @@ -492,7 +493,7 @@ def mk_func( assert type_params for ty_param in type_params: name = ty_param.getText() - self.mk_typ(name, ty.Kind.Type) + self.mk_typ(name, ty.TypeKind.Type) var_list, attr_list = self.visit(ctx.argList()) if var_list is None: @@ -528,13 +529,13 @@ def handle_adt_header( ctx: Union[RelayParser.ExternAdtDefnContext, RelayParser.AdtDefnContext]): """Handles parsing of the name and type params of an ADT definition.""" adt_name = ctx.generalIdent().getText() - adt_var = self.mk_global_typ_var(adt_name, ty.Kind.AdtHandle) + adt_var = self.mk_global_typ_var(adt_name, ty.TypeKind.AdtHandle) # parse type params type_params = ctx.typeParamList() if type_params is None: type_params = [] else: - type_params = [self.mk_typ(type_ident.getText(), ty.Kind.Type) + type_params = [self.mk_typ(type_ident.getText(), ty.TypeKind.Type) for type_ident in type_params.typeExpr()] return adt_var, type_params diff --git a/python/tvm/relay/adt.py b/python/tvm/relay/adt.py index 7f7496b1a4070..ab9f75ae30d6e 100644 --- a/python/tvm/relay/adt.py +++ b/python/tvm/relay/adt.py @@ -19,7 +19,7 @@ from .base import RelayNode, register_relay_node, Object from . import _make from .ty import Type -from .expr import Expr, Call +from .expr import ExprWithOp, RelayExpr, Call class Pattern(RelayNode): @@ -113,7 +113,7 @@ def __init__(self, patterns=None): @register_relay_node -class Constructor(Expr): +class Constructor(RelayExpr): """Relay ADT constructor.""" def __init__(self, name_hint, inputs, belong_to): @@ -206,7 +206,7 @@ def __init__(self, lhs, rhs): @register_relay_node -class Match(Expr): +class Match(ExprWithOp): """Pattern matching expression in Relay.""" def __init__(self, data, clauses, complete=True): diff --git a/python/tvm/relay/analysis.py b/python/tvm/relay/analysis.py index 5220b86501792..c4158781756cf 100644 --- a/python/tvm/relay/analysis.py +++ b/python/tvm/relay/analysis.py @@ -20,9 +20,10 @@ This file contains the set of passes for Relay, which exposes an interface for configuring the passes and scripting them in Python. """ +from tvm.ir import RelayExpr + from . import _analysis from . import _make -from .expr import Expr from .ty import Type from .module import Module from .feature import Feature @@ -400,7 +401,7 @@ def structural_hash(value): result : int The hash value """ - if isinstance(value, Expr): + if isinstance(value, RelayExpr): return int(_analysis._expr_hash(value)) elif isinstance(value, Type): return int(_analysis._type_hash(value)) diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py index bc041252a668d..5f113f5c33941 100644 --- a/python/tvm/relay/base.py +++ b/python/tvm/relay/base.py @@ -14,11 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=no-else-return, unidiomatic-typecheck +# pylint: disable=no-else-return, unidiomatic-typecheck, unused-import """The base node types for the Relay language.""" import tvm._ffi from tvm.runtime import Object +from tvm.ir import SourceName, Span, Node as RelayNode from . import _make from . import _expr from . import _base @@ -52,55 +53,6 @@ def register_relay_attr_node(type_key=None): return tvm._ffi.register_object(type_key) -class RelayNode(Object): - """Base class of all Relay nodes.""" - def astext(self, show_meta_data=True, annotate=None): - """Get the text format of the expression. - - Parameters - ---------- - show_meta_data : bool - Whether to include meta data section in the text - if there is meta data. - - annotate: Optional[relay.Expr->str] - Optional annotate function to provide additional - information in the comment block. - - Note - ---- - The meta data section is necessary to fully parse the text format. - However, it can contain dumps that are big (e.g constant weights), - so it can be helpful to skip printing the meta data section. - - Returns - ------- - text : str - The text format of the expression. - """ - return _expr.AsText(self, show_meta_data, annotate) - - def set_span(self, span): - _base.set_span(self, span) - - def __str__(self): - return self.astext(show_meta_data=False) - - -@register_relay_node -class Span(RelayNode): - """Specifies a location in a source program.""" - - def __init__(self, source, lineno, col_offset): - self.__init_handle_by_constructor__(_make.Span, source, lineno, col_offset) - -@register_relay_node -class SourceName(RelayNode): - """A identifier for a source location""" - - def __init__(self, name): - self.__init_handle_by_constructor__(_make.SourceName, name) - @register_relay_node class Id(Object): """Unique identifier(name) used in Var. diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 5add5e76a6802..e5259fbc0da8f 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name +# pylint: disable=no-else-return, invalid-name, unused-import """The expression nodes of Relay.""" from __future__ import absolute_import from numbers import Number as _Number @@ -22,33 +22,21 @@ import numpy as _np from tvm._ffi import base as _base from tvm.runtime import NDArray, convert, ndarray as _nd +from tvm.ir import RelayExpr, GlobalVar, BaseFunc from .base import RelayNode, register_relay_node from . import _make from . import _expr from . import ty as _ty +# alias relay expr as Expr. +Expr = RelayExpr # will be registered afterwards _op_make = None -class Expr(RelayNode): - """The base type for all Relay expressions.""" - @property - def checked_type(self): - """Get the checked type of tvm.relay.Expr. - - Returns - ------- - checked_type : tvm.relay.Type - The checked type. - """ - ret = self._checked_type_ - if ret is None: - raise ValueError("The type checker has not populated" - " the checked_type for this node") - return ret - +class ExprWithOp(RelayExpr): + """Basetype of all relay expressions that defines op overloading.""" def astype(self, dtype): """Cast the content type of the current data to dtype. @@ -173,7 +161,7 @@ def __call__(self, *args): return Call(self, args) @register_relay_node -class Constant(Expr): +class Constant(ExprWithOp): """A constant expression in Relay. Parameters @@ -186,7 +174,7 @@ def __init__(self, data): @register_relay_node -class Tuple(Expr): +class Tuple(ExprWithOp): """Tuple expression that groups several fields together. Parameters @@ -210,7 +198,7 @@ def astype(self, _): @register_relay_node -class Var(Expr): +class Var(ExprWithOp): """A local variable in Relay. Local variable can be used to declare input @@ -238,33 +226,7 @@ def name_hint(self): @register_relay_node -class GlobalVar(Expr): - """A global variable in Tvm.Relay. - - GlobalVar is used to refer to the global functions - stored in the module. - - Parameters - ---------- - name_hint: str - The name of the variable. - """ - def __init__(self, name_hint): - self.__init_handle_by_constructor__(_make.GlobalVar, name_hint) - - def __call__(self, *args): - """Invoke the gobal function. - - Parameters - ---------- - args: List[relay.Expr] - Arguments. - """ - return Call(self, args, None, None) - - -@register_relay_node -class Function(Expr): +class Function(BaseFunc): """A function declaration expression. Parameters @@ -320,7 +282,7 @@ def set_attribute(self, name, ref): @register_relay_node -class Call(Expr): +class Call(ExprWithOp): """Function call node in Relay. Call node corresponds the operator application node @@ -349,7 +311,7 @@ def __init__(self, op, args, attrs=None, type_args=None): @register_relay_node -class Let(Expr): +class Let(ExprWithOp): """Let variable binding expression. Parameters @@ -369,7 +331,7 @@ def __init__(self, variable, value, body): @register_relay_node -class If(Expr): +class If(ExprWithOp): """A conditional expression in Relay. Parameters @@ -389,7 +351,7 @@ def __init__(self, cond, true_branch, false_branch): @register_relay_node -class TupleGetItem(Expr): +class TupleGetItem(ExprWithOp): """Get index-th item from a tuple. Parameters @@ -406,7 +368,7 @@ def __init__(self, tuple_value, index): @register_relay_node -class RefCreate(Expr): +class RefCreate(ExprWithOp): """Create a new reference from initial value. Parameters ---------- @@ -418,7 +380,7 @@ def __init__(self, value): @register_relay_node -class RefRead(Expr): +class RefRead(ExprWithOp): """Get the value inside the reference. Parameters ---------- @@ -430,7 +392,7 @@ def __init__(self, ref): @register_relay_node -class RefWrite(Expr): +class RefWrite(ExprWithOp): """ Update the value inside the reference. The whole expression will evaluate to an empty tuple. @@ -445,7 +407,7 @@ def __init__(self, ref, value): self.__init_handle_by_constructor__(_make.RefWrite, ref, value) -class TempExpr(Expr): +class TempExpr(ExprWithOp): """Baseclass of all TempExpr. TempExprs are pass specific expression that can be diff --git a/python/tvm/relay/memory_alloc.py b/python/tvm/relay/memory_alloc.py index f93aa9eeaf2f6..d61c6f1d6fbab 100644 --- a/python/tvm/relay/memory_alloc.py +++ b/python/tvm/relay/memory_alloc.py @@ -176,7 +176,7 @@ def visit_call(self, call): view = LinearizeRetType(ret_type) out_types = view.unpack() - is_dynamic = ret_type.is_dynamic() + is_dynamic = ty.type_has_any(ret_type) # TODO(@jroesch): restore this code, more complex then it seems # for arg in call.args: # is_dynamic = is_dynamic or arg.checked_type.is_dynamic() diff --git a/python/tvm/relay/module.py b/python/tvm/relay/module.py index 5513bd711c4ff..68704ed7072b8 100644 --- a/python/tvm/relay/module.py +++ b/python/tvm/relay/module.py @@ -85,7 +85,7 @@ def __setitem__(self, var, val): return self._add(var, val) def _add(self, var, val, update=False): - if isinstance(val, _expr.Expr): + if isinstance(val, _expr.RelayExpr): if isinstance(var, _base.string_types): if _module.Module_ContainGlobalVar(self, var): var = _module.Module_GetGlobalVar(self, var) diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index c2ec6ad2d22d7..bcd58ba5b1b12 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -41,7 +41,6 @@ from . import _transform from . import _reduce from . import _algorithm -from ..expr import Expr from ..base import register_relay_node diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index f9bc853282bb6..c74201ef9c1f3 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -20,13 +20,13 @@ import tvm._ffi from ..base import register_relay_node -from ..expr import Expr +from ..expr import RelayExpr from ...api import register_func from ...build_module import lower, build from . import _make @register_relay_node -class Op(Expr): +class Op(RelayExpr): """A Relay operator definition.""" def __init__(self): diff --git a/python/tvm/relay/ty.py b/python/tvm/relay/ty.py index 356fe0beb0da5..13d7f9197e79f 100644 --- a/python/tvm/relay/ty.py +++ b/python/tvm/relay/ty.py @@ -14,133 +14,30 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name +# pylint: disable=invalid-name, unused-import """The type nodes of the Relay language.""" -from enum import IntEnum +from tvm.ir import Type, TypeKind, TypeVar, GlobalTypeVar +from tvm.ir import TypeConstraint, FuncType, TupleType, IncompleteType +from tvm.ir import TypeCall, TypeRelation, TensorType, RelayRefType as RefType + from .base import RelayNode, register_relay_node from . import _make Any = _make.Any -class Type(RelayNode): - """The base type for all Relay types.""" - - def __eq__(self, other): - """Compare two Relay types for structural equivalence using - alpha equivalence. - """ - return bool(_make._alpha_equal(self, other)) - - def __ne__(self, other): - return not self.__eq__(other) - - def same_as(self, other): - """Compares two Relay types by referential equality.""" - return super().__eq__(other) - - def __call__(self, *args): - """Create a type call from this type. +def type_has_any(tensor_type): + """Check whether type has any as a shape. - Parameters - ---------- - args: List[relay.Type] - The arguments to the type call. - - Returns - ------- - call: relay.TypeCall - """ - return TypeCall(self, args) - - def is_dynamic(self): - return _make.IsDynamic(self) - -@register_relay_node -class TensorType(Type): - """A concrete TensorType in Relay. - - This is the type assigned to tensors with a known dtype and shape. For - example, a tensor of `float32` and `(5, 5)`. - - Parameters - ---------- - shape : List[tvm.Expr] - The shape of the Tensor - - dtype : Optional[str] - The content data type. - Default to "float32". + tensor_type : Type + The type to be inspected Returns ------- - tensor_type : tvm.relay.TensorType - The tensor type. - """ - def __init__(self, shape, dtype="float32"): - self.__init_handle_by_constructor__( - _make.TensorType, shape, dtype) - - @property - def concrete_shape(self): - """Get shape of the type as concrete tuple of int. - - Returns - ------- - shape : List[int] - The concrete shape of the Type. - - Raises - ------ - TypeError : If the shape is symbolic - """ - return tuple(int(x) for x in self.shape) - - -class Kind(IntEnum): - """The kind of a type parameter, represents a variable shape, - base type, type, or dimension. - - This controls what a type parameter is allowed to be instantiated - with. For example one's of kind BaseType can only be `float32`, `int32`, - and so on. - """ - Type = 0 - ShapeVar = 1 - BaseType = 2 - Shape = 3 - Constraint = 4 - AdtHandle = 5 - TypeData = 6 - -@register_relay_node -class TypeVar(Type): - """A type variable used for generic types in Relay, - see tvm/relay/type.h for more details. - - A type variable represents a type placeholder which will - be filled in later on. This allows the user to write - functions which are generic over types. + has_any : bool + The check result. """ + return _make.IsDynamic(tensor_type) - def __init__(self, name_hint, kind=Kind.Type): - """Construct a TypeVar. - - Parameters - ---------- - name_hint: str - The name of the type variable. This name only acts as a hint, and - is not used for equality. - - kind : Optional[Kind] - The kind of the type parameter. - Default to Kind.Type. - - Returns - ------- - type_var : tvm.relay.TypeVar - The type variable. - """ - self.__init_handle_by_constructor__(_make.TypeVar, name_hint, kind) def ShapeVar(name): """A helper which constructs a type var of which the shape kind. @@ -154,172 +51,9 @@ def ShapeVar(name): type_var : tvm.relay.TypeVar The shape variable. """ - return TypeVar(name, kind=Kind.ShapeVar) - -@register_relay_node -class GlobalTypeVar(Type): - """A global type variable in Relay. - GlobalTypeVar is used to refer to the global type-level definitions - stored in the environment. - """ - - def __init__(self, name_hint, kind=Kind.AdtHandle): - """Construct a GlobalTypeVar. - - Parameters - ---------- - name_hint: str - The name of the global type variable. This name only acts as a - hint, and is not used for equality. - - kind: Kind, optional - The kind of the type parameter, Kind.AdtHandle by default. - - Returns - ------- - type_var: GlobalTypeVar - The global type variable. - """ - self.__init_handle_by_constructor__(_make.GlobalTypeVar, name_hint, kind) - - -@register_relay_node -class TypeCall(Type): - """Type-level function application in Relay. - A type call applies argument types to a constructor (type-level function). - """ - - def __init__(self, func, args): - """Construct a TypeCall. - Parameters - ---------- - func: tvm.relay.Type - The function. - args: List[tvm.expr.Type] - The arguments. - Returns - ------- - type_call: TypeCall - The type function application. - """ - self.__init_handle_by_constructor__(_make.TypeCall, func, args) - - -@register_relay_node -class TypeConstraint(Type): - """Abstract class representing a type constraint.""" - - -@register_relay_node -class TupleType(Type): - """A tuple type in Relay, see tvm/relay/type.h for more details. - - Lists the type of each field in the tuple. - """ - - def __init__(self, fields): - """Constructs a tuple type - - Parameters - ---------- - fields : List[tvm.relay.Type] - The fields in the tuple - - Returns - ------- - tuple_type : tvm.relay.TupleType - the tuple type - """ - self.__init_handle_by_constructor__(_make.TupleType, fields) - - -@register_relay_node -class FuncType(Type): - """A function type in Relay, see tvm/relay/type.h for more details. - - This is the type assigned to functions in Relay. They consist of - a list of type parameters which enable the definition of generic - functions, a set of type constraints which we omit for the time - being, a sequence of argument types, and a return type. - - We informally write them as: - `forall (type_params), (arg_types) -> ret_type where type_constraints` - - Parameters - ---------- - arg_types : List[tvm.relay.Type] - The argument types - - ret_type : tvm.relay.Type - The return type. - - type_params : Optional[List[tvm.relay.TypeVar]] - The type parameters - - type_constraints : Optional[List[tvm.relay.TypeConstraint]] - The type constraints. - """ - def __init__(self, - arg_types, - ret_type, - type_params=None, - type_constraints=None): - if type_params is None: - type_params = [] - if type_constraints is None: - type_constraints = [] - self.__init_handle_by_constructor__( - _make.FuncType, arg_types, ret_type, type_params, type_constraints) + return TypeVar(name, kind=TypeKind.ShapeVar) -@register_relay_node -class IncompleteType(Type): - """An incomplete type.""" - def __init__(self, kind=Kind.Type): - self.__init_handle_by_constructor__(_make.IncompleteType, kind) - - -@register_relay_node -class TypeRelation(TypeConstraint): - """Type relation in relay. - - Parameters - ---------- - func : EnvFunc - User defined relation function. - - args : [tvm.relay.Type] - List of types to the func. - - num_inputs : int - Number of input arguments in args, - this act as a hint for type inference. - - attrs : Attrs - The attribute attached to the relation information - - Returns - ------- - type_relation : tvm.relay.TypeRelation - The type relation. - """ - def __init__(self, func, args, num_inputs, attrs): - self.__init_handle_by_constructor__(_make.TypeRelation, - func, args, num_inputs, attrs) - - -@register_relay_node -class RefType(Type): - """Reference Type in relay. - - Parameters - ---------- - value: Type - The value type. - """ - def __init__(self, value): - self.__init_handle_by_constructor__(_make.RefType, value) - def scalar_type(dtype): """Creates a scalar type. diff --git a/python/tvm/relay/ty.pyi b/python/tvm/relay/ty.pyi deleted file mode 100644 index cde851160167b..0000000000000 --- a/python/tvm/relay/ty.pyi +++ /dev/null @@ -1,200 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name -"""The type nodes of the Relay language.""" -from enum import IntEnum -from .base import Object, register_relay_node -from . import _make - - -class Type(Object): - """The base type for all Relay types.""" - - def __eq__(self, other): - """Compare two Relay types for structural equivalence using - alpha equivalence. - """ - return bool(_make._type_alpha_eq(self, other)) - - def __ne__(self, other): - return not self.__eq__(other) - - def same_as(self, other): - """Compares two Relay types by referential equality.""" - return super().__eq__(other) - - -@register_relay_node -class TensorType(Type): - """A concrete TensorType in Relay, see tvm/relay/type.h for more details. - - This is the type assigned to tensor's with a known dype and shape. For - example a tensor of `float32` and `(5, 5)`. - """ - - def __init__(self, shape, dtype): - """Construct a tensor type. - - Parameters - ---------- - shape: list of tvm.Expr - dtype: str - - Returns - ------- - tensor_type: The TensorType - """ - self.__init_handle_by_constructor__(_make.TensorType, shape, dtype) - - -class Kind(IntEnum): - """The kind of a type parameter, represents a variable shape, - base type, type, or dimension. - - This controls what a type parameter is allowed to be instantiated - with. For example one's of kind BaseType can only be `float32`, `int32`, - and so on. - """ - ShapeVar = 0 - Shape = 1 - BaseType = 2 - Type = 3 - - -@register_relay_node -class TypeParam(Type): - """A type parameter used for generic types in Relay, - see tvm/relay/type.h for more details. - - A type parameter represents a type placeholder which will - be filled in later on. This allows the user to write - functions which are generic over types. - """ - - def __init__(self, var, kind): - """Construct a TypeParam. - - Parameters - ---------- - var: tvm.expr.Var - The tvm.Var which backs the type parameter. - - kind: Kind - The kind of the type parameter. - - Returns - ------- - type_param: TypeParam - The type parameter. - """ - self.__init_handle_by_constructor__(_make.TypeParam, var, kind) - - -@register_relay_node -class TypeConstraint(Type): - """Abstract class representing a type constraint.""" - pass - - -@register_relay_node -class TupleType(Type): - """A tuple type in Relay, see tvm/relay/type.h for more details. - - Lists the type of each field in the tuple. - """ - - def __init__(self, fields): - """Constructs a tuple type - - Parameters - ---------- - fields: list of tvm.Type - - Returns - ------- - tuple_type: the tuple type - """ - self.__init_handle_by_constructor__(_make.TupleType, fields) - - -@register_relay_node -class FuncType(Type): - """A function type in Relay, see tvm/relay/type.h for more details. - - This is the type assigned to functions in Relay. They consist of - a list of type parameters which enable the definition of generic - functions, a set of type constraints which we omit for the time - being, a sequence of argument types, and a return type. - - We informally write them as: - `forall (type_params), (arg_types) -> ret_type where type_constraints` - """ - - def __init__(self, - arg_types, - ret_type, - type_params, - type_constraints, - ): - """Construct a function type. - - Parameters - ---------- - arg_types: list of Type - ret_type: Type - type_params: list of TypeParam - type_constraints: list of TypeConstraint - - Returns - ------- - func_type: FuncType - The function type. - """ - self.__init_handle_by_constructor__( - _make.FuncType, arg_types, ret_type, type_params, type_constraints) - - -@register_relay_node -class IncompleteType(Type): - """An incomplete type.""" - - def __init__(self, kind=Kind.Type): - self.__init_handle_by_constructor__(_make.IncompleteType, kind) - -@register_relay_node -class TypeRelation(TypeConstraint): - """Type relation in relay. - - Parameters - ---------- - func : EnvFunc - User defined relation function. - - args : list of types - List of types to the func. - - num_inputs: int - Number of input arguments in args, - this act as a hint for type inference. - - attrs : Attrs - The attribute attached to the relation information - """ - def __init__(self, func, args, num_inputs, attrs): - self.__init_handle_by_constructor__(_make.TypeRelation, - func, args, num_inputs, attrs) diff --git a/src/ir/expr.cc b/src/ir/expr.cc index c061587ba3601..78c6879d8cedc 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -154,7 +154,7 @@ GlobalVar::GlobalVar(std::string name_hint) { TVM_REGISTER_NODE_TYPE(GlobalVarNode); -TVM_REGISTER_GLOBAL("relay._make.GlobalVar") +TVM_REGISTER_GLOBAL("ir.GlobalVar") .set_body_typed([](std::string name){ return GlobalVar(name); }); diff --git a/src/ir/span.cc b/src/ir/span.cc index 2ea7095c89ac8..d03903c2d3a5a 100644 --- a/src/ir/span.cc +++ b/src/ir/span.cc @@ -45,7 +45,7 @@ SourceName SourceName::Get(const std::string& name) { return SourceName(GetSourceNameNode(name)); } -TVM_REGISTER_GLOBAL("relay._make.SourceName") +TVM_REGISTER_GLOBAL("ir.SourceName") .set_body_typed(SourceName::Get); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -70,7 +70,7 @@ Span SpanNode::make(SourceName source, int lineno, int col_offset) { TVM_REGISTER_NODE_TYPE(SpanNode); -TVM_REGISTER_GLOBAL("relay._make.Span") +TVM_REGISTER_GLOBAL("ir.Span") .set_body_typed(SpanNode::make); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) diff --git a/src/ir/tensor_type.cc b/src/ir/tensor_type.cc index 5e7c51c72d9b5..57cdebc931fbb 100644 --- a/src/ir/tensor_type.cc +++ b/src/ir/tensor_type.cc @@ -55,7 +55,7 @@ PrimExpr TensorTypeNode::Size() const { TVM_REGISTER_NODE_TYPE(TensorTypeNode); -TVM_REGISTER_GLOBAL("relay._make.TensorType") +TVM_REGISTER_GLOBAL("ir.TensorType") .set_body_typed([](Array shape, DataType dtype) { return TensorType(shape, dtype); }); diff --git a/src/ir/type.cc b/src/ir/type.cc index 02ddfc9371fd0..e0420aaf754ae 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -33,7 +33,7 @@ PrimType::PrimType(runtime::DataType dtype) { TVM_REGISTER_NODE_TYPE(PrimTypeNode); -TVM_REGISTER_GLOBAL("relay._make.PrimType") +TVM_REGISTER_GLOBAL("ir.PrimType") .set_body_typed([](runtime::DataType dtype) { return PrimType(dtype); }); @@ -54,7 +54,7 @@ TypeVar::TypeVar(std::string name, TypeKind kind) { TVM_REGISTER_NODE_TYPE(TypeVarNode); -TVM_REGISTER_GLOBAL("relay._make.TypeVar") +TVM_REGISTER_GLOBAL("ir.TypeVar") .set_body_typed([](std::string name, int kind) { return TypeVar(name, static_cast(kind)); }); @@ -76,7 +76,7 @@ GlobalTypeVar::GlobalTypeVar(std::string name, TypeKind kind) { TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode); -TVM_REGISTER_GLOBAL("relay._make.GlobalTypeVar") +TVM_REGISTER_GLOBAL("ir.GlobalTypeVar") .set_body_typed([](std::string name, int kind) { return GlobalTypeVar(name, static_cast(kind)); }); @@ -102,7 +102,7 @@ FuncType::FuncType(tvm::Array arg_types, TVM_REGISTER_NODE_TYPE(FuncTypeNode); -TVM_REGISTER_GLOBAL("relay._make.FuncType") +TVM_REGISTER_GLOBAL("ir.FuncType") .set_body_typed([](tvm::Array arg_types, Type ret_type, tvm::Array type_params, @@ -131,7 +131,7 @@ TupleType TupleType::Empty() { TVM_REGISTER_NODE_TYPE(TupleTypeNode); -TVM_REGISTER_GLOBAL("relay._make.TupleType") +TVM_REGISTER_GLOBAL("ir.TupleType") .set_body_typed([](Array fields) { return TupleType(fields); }); @@ -151,7 +151,7 @@ IncompleteType::IncompleteType(TypeKind kind) { TVM_REGISTER_NODE_TYPE(IncompleteTypeNode); -TVM_REGISTER_GLOBAL("relay._make.IncompleteType") +TVM_REGISTER_GLOBAL("ir.IncompleteType") .set_body_typed([](int kind) { return IncompleteType(static_cast(kind)); }); @@ -169,7 +169,7 @@ RelayRefType::RelayRefType(Type value) { data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relay._make.RefType") +TVM_REGISTER_GLOBAL("ir.RelayRefType") .set_body_typed([](Type value) { return RelayRefType(value); }); diff --git a/src/ir/type_relation.cc b/src/ir/type_relation.cc index 1d80f95b10c95..bd79c9c7fd160 100644 --- a/src/ir/type_relation.cc +++ b/src/ir/type_relation.cc @@ -35,7 +35,7 @@ TypeCall::TypeCall(Type func, tvm::Array args) { TVM_REGISTER_NODE_TYPE(TypeCallNode); -TVM_REGISTER_GLOBAL("relay._make.TypeCall") +TVM_REGISTER_GLOBAL("ir.TypeCall") .set_body_typed([](Type func, Array type) { return TypeCall(func, type); }); @@ -61,7 +61,7 @@ TypeRelation::TypeRelation(TypeRelationFn func, TVM_REGISTER_NODE_TYPE(TypeRelationNode); -TVM_REGISTER_GLOBAL("relay._make.TypeRelation") +TVM_REGISTER_GLOBAL("ir.TypeRelation") .set_body_typed([](TypeRelationFn func, Array args, int num_inputs, diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 0fa4da5b5077e..00bf70b5d2895 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -131,6 +131,7 @@ class RelayTextPrinter : } else if (node.as()) { return PrintMod(Downcast(node)); } else { + // default module. std::ostringstream os; os << node; return Doc() << os.str(); @@ -905,20 +906,18 @@ static const char* kSemVer = "v0.0.4"; // - relay_text_printer.cc (specific printing logics for relay) // - tir_text_printer.cc (specific printing logics for TIR) std::string PrettyPrint(const ObjectRef& node) { - Doc doc; - doc << relay::RelayTextPrinter(false, nullptr).PrintFinal(node); - return doc.str(); + return AsText(node, false, nullptr); } std::string AsText(const ObjectRef& node, bool show_meta_data, runtime::TypedPackedFunc annotate) { Doc doc; - doc << kSemVer << Doc::NewLine() - << relay::RelayTextPrinter(show_meta_data, annotate).PrintFinal(node); + doc << kSemVer << Doc::NewLine(); + doc << relay::RelayTextPrinter(show_meta_data, annotate).PrintFinal(node); return doc.str(); } -TVM_REGISTER_GLOBAL("relay._expr.AsText") +TVM_REGISTER_GLOBAL("ir.AsText") .set_body_typed(AsText); } // namespace tvm diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index 2d07f6131f135..48634bafa744a 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -599,6 +599,11 @@ TVM_REGISTER_GLOBAL("relay._make._alpha_equal") return AlphaEqualHandler(false, false).Equal(a, b); }); +TVM_REGISTER_GLOBAL("ir.type_alpha_equal") +.set_body_typed([](Type a, Type b) { + return AlphaEqual(a, b); +}); + TVM_REGISTER_GLOBAL("relay._make._assert_alpha_equal") .set_body_typed([](ObjectRef a, ObjectRef b) { bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b); diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc index 85b17b5a33e0a..22423b8dfe5f5 100644 --- a/src/relay/ir/base.cc +++ b/src/relay/ir/base.cc @@ -33,7 +33,7 @@ using namespace tvm::runtime; TVM_REGISTER_NODE_TYPE(IdNode); -TVM_REGISTER_GLOBAL("relay._base.set_span") +TVM_REGISTER_GLOBAL("ir.NodeSetSpan") .set_body_typed([](ObjectRef node_ref, Span sp) { if (auto* rn = node_ref.as()) { rn->span = sp;