Skip to content

Commit

Permalink
[IR] Update the type_keys to reflect the code-org (#5074)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored Mar 15, 2020
1 parent 7c5ff50 commit 6027412
Show file tree
Hide file tree
Showing 18 changed files with 255 additions and 118 deletions.
2 changes: 1 addition & 1 deletion include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ class GlobalVarNode : public RelayExprNode {
v->Visit("_checked_type_", &checked_type_);
}

static constexpr const char* _type_key = "relay.GlobalVar";
static constexpr const char* _type_key = "GlobalVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelayExprNode);
};

Expand Down
2 changes: 1 addition & 1 deletion include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ class IRModuleNode : public Object {
*/
TVM_DLL std::unordered_set<std::string> Imports() const;

static constexpr const char* _type_key = "relay.Module";
static constexpr const char* _type_key = "IRModule";
TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object);

private:
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/ir/span.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class SourceNameNode : public Object {
// override attr visitor
void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); }

static constexpr const char* _type_key = "relay.SourceName";
static constexpr const char* _type_key = "SourceName";
TVM_DECLARE_FINAL_OBJECT_INFO(SourceNameNode, Object);
};

Expand Down Expand Up @@ -89,7 +89,7 @@ class SpanNode : public Object {

TVM_DLL static Span make(SourceName source, int lineno, int col_offset);

static constexpr const char* _type_key = "relay.Span";
static constexpr const char* _type_key = "Span";
TVM_DECLARE_FINAL_OBJECT_INFO(SpanNode, Object);
};

Expand Down
6 changes: 3 additions & 3 deletions include/tvm/ir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class PassContextNode : public Object {
v->Visit("disabled_pass", &disabled_pass);
}

static constexpr const char* _type_key = "relay.PassContext";
static constexpr const char* _type_key = "transform.PassContext";
TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object);
};

Expand Down Expand Up @@ -206,7 +206,7 @@ class PassInfoNode : public Object {
v->Visit("required", &required);
}

static constexpr const char* _type_key = "relay.PassInfo";
static constexpr const char* _type_key = "transform.PassInfo";
TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, Object);
};

Expand Down Expand Up @@ -265,7 +265,7 @@ class PassNode : public Object {

void VisitAttrs(AttrVisitor* v) {}

static constexpr const char* _type_key = "relay.Pass";
static constexpr const char* _type_key = "transform.Pass";
TVM_DECLARE_BASE_OBJECT_INFO(PassNode, Object);
};

Expand Down
18 changes: 10 additions & 8 deletions include/tvm/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class TypeNode : public Object {
*/
mutable Span span;

static constexpr const char* _type_key = "relay.Type";
static constexpr const char* _type_key = "Type";
TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object);
};

Expand Down Expand Up @@ -110,7 +110,7 @@ class PrimTypeNode : public TypeNode {
v->Visit("dtype", &dtype);
}

static constexpr const char* _type_key = "relay.PrimType";
static constexpr const char* _type_key = "PrimType";
TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode);
};

Expand Down Expand Up @@ -175,7 +175,7 @@ class TypeVarNode : public TypeNode {
v->Visit("span", &span);
}

static constexpr const char* _type_key = "relay.TypeVar";
static constexpr const char* _type_key = "TypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode);
};

Expand Down Expand Up @@ -215,7 +215,7 @@ class GlobalTypeVarNode : public TypeNode {
v->Visit("kind", &kind);
}

static constexpr const char* _type_key = "relay.GlobalTypeVar";
static constexpr const char* _type_key = "GlobalTypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode);
};

Expand Down Expand Up @@ -251,7 +251,7 @@ class TupleTypeNode : public TypeNode {
v->Visit("span", &span);
}

static constexpr const char* _type_key = "relay.TupleType";
static constexpr const char* _type_key = "TupleType";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode);
};

Expand Down Expand Up @@ -289,7 +289,7 @@ inline Type VoidType() {
*/
class TypeConstraintNode : public TypeNode {
public:
static constexpr const char* _type_key = "relay.TypeConstraint";
static constexpr const char* _type_key = "TypeConstraint";
TVM_DECLARE_BASE_OBJECT_INFO(TypeConstraintNode, TypeNode);
};

Expand Down Expand Up @@ -334,7 +334,7 @@ class FuncTypeNode : public TypeNode {
v->Visit("span", &span);
}

static constexpr const char* _type_key = "relay.FuncType";
static constexpr const char* _type_key = "FuncType";
TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode);
};

Expand Down Expand Up @@ -380,7 +380,7 @@ class IncompleteTypeNode : public TypeNode {
v->Visit("span", &span);
}

static constexpr const char* _type_key = "relay.IncompleteType";
static constexpr const char* _type_key = "IncompleteType";
TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode);
};

Expand Down Expand Up @@ -417,6 +417,8 @@ class RelayRefTypeNode : public TypeNode {
v->Visit("span", &span);
}

// Keep the relay prefix in the type as this type is specific
// to the relay itself.
static constexpr const char* _type_key = "relay.RefType";
TVM_DECLARE_FINAL_OBJECT_INFO(RelayRefTypeNode, TypeNode);
};
Expand Down
6 changes: 3 additions & 3 deletions include/tvm/ir/type_relation.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class TypeCallNode : public TypeNode {
v->Visit("span", &span);
}

static constexpr const char* _type_key = "relay.TypeCall";
static constexpr const char* _type_key = "TypeCall";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeCallNode, TypeNode);
};

Expand Down Expand Up @@ -119,7 +119,7 @@ class TypeReporterNode : public Object {
// solver is not serializable.
void VisitAttrs(AttrVisitor* v) {}

static constexpr const char* _type_key = "relay.TypeReporter";
static constexpr const char* _type_key = "TypeReporter";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeReporterNode, Object);
};

Expand Down Expand Up @@ -195,7 +195,7 @@ class TypeRelationNode : public TypeConstraintNode {
v->Visit("span", &span);
}

static constexpr const char* _type_key = "relay.TypeRelation";
static constexpr const char* _type_key = "TypeRelation";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode);
};

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# pylint: disable=unused-import
"""Common data structures across all IR variants."""
from .base import SourceName, Span, Node, EnvFunc, load_json, save_json
from .type import Type, TypeKind, TypeVar, GlobalTypeVar, TupleType
from .type import Type, TypeKind, PrimType, TypeVar, GlobalTypeVar, TupleType
from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
from .tensor_type import TensorType
from .type_relation import TypeCall, TypeRelation
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/ir/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __str__(self):
return _ffi_api.PrettyPrint(self)


@tvm._ffi.register_object("relay.SourceName")
@tvm._ffi.register_object("SourceName")
class SourceName(Object):
"""A identifier for a source location.
Expand All @@ -69,7 +69,7 @@ def __init__(self, name):
self.__init_handle_by_constructor__(_ffi_api.SourceName, name)


@tvm._ffi.register_object("relay.Span")
@tvm._ffi.register_object("Span")
class Span(Object):
"""Specifies a location in a source program.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/ir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def checked_type(self):
return ret


@tvm._ffi.register_object("relay.GlobalVar")
@tvm._ffi.register_object("GlobalVar")
class GlobalVar(RelayExpr):
"""A global variable in the IR.
Expand Down
24 changes: 24 additions & 0 deletions python/tvm/ir/json_compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,35 @@ def _ftype_var(item, nodes):
# set vindex to null
nodes[vindex]["type_key"] = ""
del item["attrs"]["var"]
assert item["type_key"].startswith("relay.")
item["type_key"] = item["type_key"][len("relay."):]
return item

def _rename(new_name):
def _convert(item, _):
item["type_key"] = new_name
return item
return _convert

node_map = {
"relay.TypeVar": _ftype_var,
"relay.GlobalTypeVar": _ftype_var,
"relay.Type": _rename("Type"),
"relay.TupleType": _rename("TupleType"),
"relay.TypeConstraint": _rename("TypeConstraint"),
"relay.FuncType": _rename("FuncType"),
"relay.IncompleteType": _rename("IncompleteType"),
"relay.TypeRelation": _rename("TypeRelation"),
"relay.TypeCall": _rename("TypeCall"),
"relay.Module": _rename("IRModule"),
"relay.SourceName": _rename("SourceName"),
"relay.Span": _rename("Span"),
"relay.GlobalVar": _rename("GlobalVar"),
"relay.Pass": _rename("transform.Pass"),
"relay.PassInfo": _rename("transform.PassInfo"),
"relay.PassContext": _rename("transform.PassContext"),
"relay.ModulePass": _rename("transform.ModulePass"),
"relay.Sequantial": _rename("transform.Sequantial"),
}
return create_updater(node_map, "0.6", "0.7")

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/ir/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from . import _ffi_api


@tvm._ffi.register_object("relay.Module")
@tvm._ffi.register_object("IRModule")
class IRModule(Node):
"""IRModule that holds functions and type definitions.
Expand Down
10 changes: 5 additions & 5 deletions python/tvm/ir/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from . import _ffi_transform_api

@tvm._ffi.register_object("relay.PassInfo")
@tvm._ffi.register_object("transform.PassInfo")
class PassInfo(Object):
"""The class contains the meta data required by a pass. It is the
container of information needed by running an optimization or analysis.
Expand All @@ -51,7 +51,7 @@ def __init__(self, opt_level, name, required=None):
_ffi_transform_api.PassInfo, opt_level, name, required)


@tvm._ffi.register_object("relay.PassContext")
@tvm._ffi.register_object("transform.PassContext")
class PassContext(Object):
"""The basis where a Relay optimization/analysis runs on.
Each pass context contains a number of auxiliary information that is used
Expand Down Expand Up @@ -112,7 +112,7 @@ def current():
return _ffi_transform_api.GetCurrentPassContext()


@tvm._ffi.register_object("relay.Pass")
@tvm._ffi.register_object("transform.Pass")
class Pass(Object):
"""The base class of all passes. All methods here are just simple wrappers
that are implemented in the backend. They are defined for users to
Expand Down Expand Up @@ -141,7 +141,7 @@ def __call__(self, mod):
return _ffi_transform_api.RunPass(self, mod)


@tvm._ffi.register_object("relay.ModulePass")
@tvm._ffi.register_object("transform.ModulePass")
class ModulePass(Pass):
"""A pass that works on tvm.IRModule. Users don't need to interact with
this class directly. Instead, a module pass should be created through
Expand All @@ -152,7 +152,7 @@ class ModulePass(Pass):
"""


@tvm._ffi.register_object("relay.Sequential")
@tvm._ffi.register_object("transform.Sequential")
class Sequential(Pass):
"""A pass that works on a sequence of pass objects. Multiple passes can be
executed sequentially using this class.
Expand Down
25 changes: 19 additions & 6 deletions python/tvm/ir/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,20 @@ class TypeKind(IntEnum):
TypeData = 6


@tvm._ffi.register_object("relay.TypeVar")
class PrimType(Type):
"""Primitive data type in the low level IR
Parameters
----------
dtype : str
The runtime data type relates to the primtype.
"""
def __init__(self, dtype):
self.__init_handle_by_constructor__(
_ffi_api.PrimType, dtype)


@tvm._ffi.register_object("TypeVar")
class TypeVar(Type):
"""Type parameter in functions.
Expand Down Expand Up @@ -85,7 +98,7 @@ def __call__(self, *args):
return TypeCall(self, args)


@tvm._ffi.register_object("relay.GlobalTypeVar")
@tvm._ffi.register_object("GlobalTypeVar")
class GlobalTypeVar(Type):
"""A global type variable that is used for defining new types or type aliases.
Expand Down Expand Up @@ -120,7 +133,7 @@ def __call__(self, *args):
return TypeCall(self, args)


@tvm._ffi.register_object("relay.TupleType")
@tvm._ffi.register_object("TupleType")
class TupleType(Type):
"""The type of tuple values.
Expand All @@ -135,12 +148,12 @@ def __init__(self, fields):
_ffi_api.TupleType, fields)


@tvm._ffi.register_object("relay.TypeConstraint")
@tvm._ffi.register_object("TypeConstraint")
class TypeConstraint(Type):
"""Abstract class representing a type constraint."""


@tvm._ffi.register_object("relay.FuncType")
@tvm._ffi.register_object("FuncType")
class FuncType(Type):
"""Function type.
Expand Down Expand Up @@ -179,7 +192,7 @@ def __init__(self,
_ffi_api.FuncType, arg_types, ret_type, type_params, type_constraints)


@tvm._ffi.register_object("relay.IncompleteType")
@tvm._ffi.register_object("IncompleteType")
class IncompleteType(Type):
"""Incomplete type during type inference.
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/ir/type_relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from . import _ffi_api


@tvm._ffi.register_object("TypeCall")
class TypeCall(Type):
"""Type function application.
Expand All @@ -41,7 +42,7 @@ def __init__(self, func, args):
self.__init_handle_by_constructor__(_ffi_api.TypeCall, func, args)


@tvm._ffi.register_object("relay.TypeRelation")
@tvm._ffi.register_object("TypeRelation")
class TypeRelation(TypeConstraint):
"""User defined type relation, it is an input-output relation on types.
Expand Down
4 changes: 2 additions & 2 deletions src/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ class ModulePassNode : public PassNode {
*/
PassInfo Info() const override { return pass_info; }

static constexpr const char* _type_key = "relay.ModulePass";
static constexpr const char* _type_key = "transform.ModulePass";
TVM_DECLARE_FINAL_OBJECT_INFO(ModulePassNode, PassNode);
};

Expand Down Expand Up @@ -206,7 +206,7 @@ class SequentialNode : public PassNode {
*/
IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final;

static constexpr const char* _type_key = "relay.Sequential";
static constexpr const char* _type_key = "transform.Sequential";
TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode);
};

Expand Down
Loading

0 comments on commit 6027412

Please sign in to comment.