diff --git a/CMakeLists.txt b/CMakeLists.txt index 7d5c04a4f1ef..b823528c4817 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -125,6 +125,8 @@ assign_source_group("Include" ${GROUP_INCLUDE}) # Source file lists file(GLOB COMPILER_SRCS + src/node/*.cc + src/ir/*.cc src/api/*.cc src/arithmetic/*.cc src/autotvm/*.cc @@ -132,7 +134,6 @@ file(GLOB COMPILER_SRCS src/lang/*.cc src/pass/*.cc src/op/*.cc - src/node/*.cc src/schedule/*.cc ) diff --git a/include/tvm/ir/span.h b/include/tvm/ir/span.h new file mode 100644 index 000000000000..8cbfff73d0c8 --- /dev/null +++ b/include/tvm/ir/span.h @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ir/span.h + * \brief Span information for debugging purposes. + */ +#ifndef TVM_IR_SPAN_H_ +#define TVM_IR_SPAN_H_ + +#include +#include +#include + +namespace tvm { +/*! + * \brief The source name in the Span + * \sa SourceNameNode, Span + */ +class SourceName; +/*! + * \brief The name of a source fragment. + */ +class SourceNameNode : public Object { + public: + /*! \brief The source name. */ + std::string name; + // override attr visitor + void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); } + + static constexpr const char* _type_key = "relay.SourceName"; + TVM_DECLARE_FINAL_OBJECT_INFO(SourceNameNode, Object); +}; + +/*! + * \brief The source name of a file span. + * \sa SourceNameNode, Span + */ +class SourceName : public ObjectRef { + public: + /*! + * \brief Get an SourceName for a given operator name. + * Will raise an error if the source name has not been registered. + * \param name Name of the operator. + * \return SourceName valid throughout program lifetime. + */ + TVM_DLL static SourceName Get(const std::string& name); + + TVM_DEFINE_OBJECT_REF_METHODS(SourceName, ObjectRef, SourceNameNode); +}; + +/*! + * \brief Span information for debugging purposes + */ +class Span; +/*! + * \brief Stores locations in frontend source that generated a node. + */ +class SpanNode : public Object { + public: + /*! \brief The source name */ + SourceName source; + /*! \brief Line number */ + int lineno; + /*! \brief column offset */ + int col_offset; + // override attr visitor + void VisitAttrs(AttrVisitor* v) { + v->Visit("source", &source); + v->Visit("lineno", &lineno); + v->Visit("col_offset", &col_offset); + } + + TVM_DLL static Span make(SourceName source, int lineno, int col_offset); + + static constexpr const char* _type_key = "relay.Span"; + TVM_DECLARE_FINAL_OBJECT_INFO(SpanNode, Object); +}; + + +class Span : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Span, ObjectRef, SpanNode); +}; + +} // namespace tvm +#endif // TVM_IR_SPAN_H_ diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h new file mode 100644 index 000000000000..ffe1ba876ac6 --- /dev/null +++ b/include/tvm/ir/type.h @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ir/type.h + * \brief IR/AST nodes for the unified type system in TVM. + * + * We use Relay's type system as the unified type system + * throughout the stack. + * + * This file contains types that are common across IR variants. + * + * ## Relation between Type and runtime::DataType + * + * Besides Type, we also store a dtype field in some of the low-level IR's Expr. + * runtime::DataType(dtype) provides coarse grained type information + * during compile time and runtime. It is eagerly built in + * low-level expression construction and can be used for + * quick type checking in the low-level IR. + * For example, when an Expr's dtype is int32, + * we know for sure that its type is also int32. + * + * On the other hand, Type provides more fine grained information. + * For example, a low level expression can have DataType::Handle() as + * its dtype and MemRef[float32] as its type. + * Types are usually lazily constructed via type checking, + * so they may not readily be available during IR construction. + * + * The unified Type serves as a common bridge across IR dialects. + * For example, we require all the functions to have a type signature, + * which allow us to build cross dialect function calls. + */ +#ifndef TVM_IR_TYPE_H_ +#define TVM_IR_TYPE_H_ + +#include +#include +#include +#include +#include + +namespace tvm { + +/*! \brief Base type of all the types. */ +class TypeNode : public Object { + public: + /*! + * \brief Span that points to the original source code. + * Reserved debug information. + */ + mutable Span span; + + static constexpr const char* _type_key = "relay.Type"; + TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object); +}; + +/*! + * \brief Type is the base type of all types. + * + * Relay's type system contains following two key concepts: + * + * - PrimitiveType: type of primitive type values used in the low-level IR. + * - TensorType: type of certain Tensor values in the expression. + * - FunctionType: the type of the function. + * + * There are also advanced types to support generic(polymorphic types), + * which can be ignored when first reading the code base. + */ +class Type : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Type, ObjectRef, TypeNode); +}; + +/*! \brief Possible kinds of TypeVars. */ +enum TypeKind : int { + kType = 0, + /*! \brief Template variable in shape expression. */ + kShapeVar = 1, + kBaseType = 2, + kShape = 3, + kConstraint = 4, + kAdtHandle = 5, + kTypeData = 6 +}; + +/*! + * \brief Type parameter in the function. + * This can be viewed as template parameter in c++ template function. + * + * For example, in the following pesudo code, + * the TypeVar of f is TypeVar(kind=kShapeVar, var=n). + * This function can take in a Tensor with shape=(3, 3) and + * returns a Tensor with shape=(9,) + * + * \code + * + * template + * f(x : Tensor[i32, (n, n)]) -> Tensor[i32, (n * n)] + * + * \endcode + * \sa TypeVarNode The actual container class of TypeVar + */ +class TypeVar; +/*! \brief TypeVar container node */ +class TypeVarNode : public TypeNode { + public: + /*! + * \brief The name of the variable, + * this only acts as a hint to the user, + * and is not used for equality. + */ + std::string name_hint; + /*! \brief The kind of type parameter */ + TypeKind kind; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("name_hint", &name_hint); + v->Visit("kind", &kind); + v->Visit("span", &span); + } + + TVM_DLL static TypeVar make(std::string name, TypeKind kind); + + static constexpr const char* _type_key = "relay.TypeVar"; + TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode); +}; + +class TypeVar : public Type { + public: + TVM_DEFINE_OBJECT_REF_METHODS(TypeVar, Type, TypeVarNode); +}; + +/*! + * \brief A global type variable that is used for defining new types or type aliases. + */ +class GlobalTypeVar; +/*! \brief GlobalTypeVar container node */ +class GlobalTypeVarNode : public TypeNode { + public: + /*! + * \brief The name of the variable, + * this only acts as a hint to the user, + * and is not used for equality. + */ + std::string name_hint; + /*! \brief The kind of type parameter */ + TypeKind kind; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("name_hint", &name_hint); + v->Visit("kind", &kind); + } + + TVM_DLL static GlobalTypeVar make(std::string name, TypeKind kind); + + static constexpr const char* _type_key = "relay.GlobalTypeVar"; + TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode); +}; + +class GlobalTypeVar : public Type { + public: + TVM_DEFINE_OBJECT_REF_METHODS(GlobalTypeVar, Type, GlobalTypeVarNode); +}; + +/*! + * \brief Potential Constraints in the type. + * \note This is reserved for future use. + */ +class TypeConstraint; +/*! \brief TypeConstraint container node. */ +class TypeConstraintNode : public TypeNode { + public: + static constexpr const char* _type_key = "relay.TypeConstraint"; + TVM_DECLARE_BASE_OBJECT_INFO(TypeConstraintNode, TypeNode); +}; + +class TypeConstraint : public Type { + public: + TVM_DEFINE_OBJECT_REF_METHODS(TypeConstraint, Type, TypeConstraintNode); +}; + +class FuncType; +/*! + * \brief Function type in Relay. + * + * Relay support polymorphic function type. + * This can be roughly viewed as template function in C++. + * + * \sa TypeVar, TypeConstraint + */ +class FuncTypeNode : public TypeNode { + public: + /*! \brief type type of arguments */ + Array arg_types; + /*! \brief The type of return value. */ + Type ret_type; + // The following fields are used in polymorphic(template) functions + // For normal functions, the following two fields will be empty. + /*! \brief The type parameters of the function */ + Array type_params; + /*! + * \brief potential constraint the type need to obey + * \note this field is reserved for futher purposes. + */ + Array type_constraints; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("arg_types", &arg_types); + v->Visit("ret_type", &ret_type); + v->Visit("type_params", &type_params); + v->Visit("type_constraints", &type_constraints); + v->Visit("span", &span); + } + + TVM_DLL static FuncType make(Array arg_types, + Type ret_type, + Array type_params, + Array type_constraints); + + static constexpr const char* _type_key = "relay.FuncType"; + TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode); +}; + +class FuncType : public Type { + public: + TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode); +}; + +} // namespace tvm +#endif // TVM_IR_TYPE_H_ diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index d64d05f119bb..7191e1f02d14 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -25,6 +25,7 @@ #define TVM_RELAY_BASE_H_ #include +#include #include #include #include @@ -58,88 +59,9 @@ namespace relay { */ using IndexExpr = ::tvm::Expr; -/*! - * \brief The source name in the Span - * \sa SourceNameNode, Span - */ -class SourceName; -/*! - * \brief The name of a source fragment. - */ -class SourceNameNode : public Object { - public: - /*! \brief The source name. */ - std::string name; - // override attr visitor - void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); } - - static constexpr const char* _type_key = "relay.SourceName"; - TVM_DECLARE_FINAL_OBJECT_INFO(SourceNameNode, Object); -}; - -/*! - * \brief The source name of a file span. - * \sa SourceNameNode, Span - */ -class SourceName : public ObjectRef { - public: - /*! \brief default constructor */ - SourceName() {} - - /*! \brief constructor from node pointer */ - explicit SourceName(ObjectPtr n) : ObjectRef(n) {} - /*! - * \brief access the internal node container - * \return the pointer to the internal node container - */ - inline const SourceNameNode* operator->() const { - return static_cast(get()); - } - - /*! - * \brief Get an SourceName for a given operator name. - * Will raise an error if the source name has not been registered. - * \param name Name of the operator. - * \return SourceName valid throughout program lifetime. - */ - TVM_DLL static SourceName Get(const std::string& name); - - /*! \brief specify container node */ - using ContainerType = SourceNameNode; -}; - -/*! - * \brief Span information for debugging purposes - */ -class Span; -/*! - * \brief Stores locations in frontend source that generated a node. - */ -class SpanNode : public Object { - public: - /*! \brief The source name */ - SourceName source; - /*! \brief Line number */ - int lineno; - /*! \brief column offset */ - int col_offset; - // override attr visitor - void VisitAttrs(AttrVisitor* v) { - v->Visit("source", &source); - v->Visit("lineno", &lineno); - v->Visit("col_offset", &col_offset); - } - - TVM_DLL static Span make(SourceName source, int lineno, int col_offset); - - static constexpr const char* _type_key = "relay.Span"; - TVM_DECLARE_FINAL_OBJECT_INFO(SpanNode, Object); -}; - -class Span : public ObjectRef { - public: - TVM_DEFINE_OBJECT_REF_METHODS(Span, ObjectRef, SpanNode); -}; +using SourceName = tvm::SourceName; +using Span = tvm::Span; +using SpanNode = tvm::SpanNode; /*! * \brief This is the base node container of all relay structures. diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 8f51ea93821d..c6a560ab9cda 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -25,8 +25,8 @@ #define TVM_RELAY_TYPE_H_ #include +#include #include -#include #include #include "base.h" @@ -36,32 +36,17 @@ namespace tvm { namespace relay { using Any = tvm::ir::Any; - -/*! \brief Base type of the Relay type hiearchy. */ -class TypeNode : public RelayNode { - public: - static constexpr const char* _type_key = "relay.Type"; - TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object); -}; - -/*! - * \brief Type is the base type of relay type hiearchy. - * - * Relay's type system contains following two key concepts: - * - * - TensorType: type of certain Tensor values in the expression. - * - FunctionType: the type of the function. - * - * There are also advanced types to support generic(polymorphic types), - * which can be ignored when first reading the code base. - */ -class Type : public ObjectRef { - public: - Type() {} - explicit Type(ObjectPtr p) : ObjectRef(p) {} - - using ContainerType = TypeNode; -}; +using Kind = TypeKind; +using Type = tvm::Type; +using TypeNode = tvm::TypeNode; +using TypeVar = tvm::TypeVar; +using TypeVarNode = tvm::TypeVarNode; +using GlobalTypeVar = tvm::GlobalTypeVar; +using GlobalTypeVarNode = tvm::GlobalTypeVarNode; +using TypeConstraint = tvm::TypeConstraint; +using TypeConstraintNode = tvm::TypeConstraintNode; +using FuncType = tvm::FuncType; +using FuncTypeNode = tvm::FuncTypeNode; /*! * \brief Base of all Tensor types @@ -124,90 +109,6 @@ class TensorType : public Type { TVM_DEFINE_OBJECT_REF_METHODS(TensorType, Type, TensorTypeNode); }; -/*! \brief Possible kinds of Type. */ -enum Kind : int { - kType = 0, - /*! \brief Template variable in shape expression. */ - kShapeVar = 1, - kBaseType = 2, - kShape = 3, - kConstraint = 4, - kAdtHandle = 5, - kTypeData = 6 -}; - -/*! - * \brief Type parameter in the function. - * This can be viewed as template parameter in c++ template function. - * - * For example, in the following pesudo code, - * the TypeVar of f is TypeVar(kind=kShapeVar, var=n). - * This function can take in a Tensor with shape=(3, 3) and - * returns a Tensor with shape=(9,) - * - * \code - * - * template - * f(x : Tensor[i32, (n, n)]) -> Tensor[i32, (n * n)] - * - * \endcode - * \sa TypeVarNode The actual container class of TypeVar - */ -class TypeVar; -/*! \brief TypeVar container node */ -class TypeVarNode : public TypeNode { - public: - /*! \brief Name of the variable, it only acts as a hint. */ - std::string name_hint; - /*! \brief The kind of type parameter */ - Kind kind; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("name_hint", &name_hint); - v->Visit("kind", &kind); - v->Visit("span", &span); - } - - TVM_DLL static TypeVar make(std::string name, Kind kind); - - static constexpr const char* _type_key = "relay.TypeVar"; - TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode); -}; - -class TypeVar : public Type { - public: - TVM_DEFINE_OBJECT_REF_METHODS(TypeVar, Type, TypeVarNode); -}; - -/*! - * \brief A global type variable that is used for defining new types or type aliases. - */ -class GlobalTypeVar; -/*! \brief GlobalTypeVar container node */ -class GlobalTypeVarNode : public TypeNode { - public: - /*! \brief Name of the variable, it only acts as a hint. */ - std::string name_hint; - /*! \brief The kind of type parameter */ - Kind kind; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("name_hint", &name_hint); - v->Visit("kind", &kind); - v->Visit("span", &span); - } - - TVM_DLL static GlobalTypeVar make(std::string name, Kind kind); - - static constexpr const char* _type_key = "relay.GlobalTypeVar"; - TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode); -}; - -class GlobalTypeVar : public Type { - public: - TVM_DEFINE_OBJECT_REF_METHODS(GlobalTypeVar, Type, GlobalTypeVarNode); -}; - /*! * \brief Type application. */ @@ -270,70 +171,6 @@ class IncompleteType : public Type { TVM_DEFINE_OBJECT_REF_METHODS(IncompleteType, Type, IncompleteTypeNode); }; -/*! - * \brief Potential Constraints in the type. - * \note This is reserved for future use. - */ -class TypeConstraint; -/*! \brief TypeConstraint container node. */ -class TypeConstraintNode : public TypeNode { - public: - static constexpr const char* _type_key = "relay.TypeConstraint"; - TVM_DECLARE_BASE_OBJECT_INFO(TypeConstraintNode, TypeNode); -}; - -class TypeConstraint : public Type { - public: - TVM_DEFINE_OBJECT_REF_METHODS(TypeConstraint, Type, TypeConstraintNode); -}; - -class FuncType; -/*! - * \brief Function type in Relay. - * - * Relay support polymorphic function type. - * This can be roughly viewed as template function in C++. - * - * \sa TypeVar, TypeConstraint - */ -class FuncTypeNode : public TypeNode { - public: - /*! \brief type type of arguments */ - tvm::Array arg_types; - /*! \brief The type of return value. */ - Type ret_type; - // The following fields are used in polymorphic(template) functions - // For normal functions, the following two fields will be empty. - /*! \brief The type parameters of the function */ - tvm::Array type_params; - /*! - * \brief potential constraint the type need to obey - * \note this field is reserved for futher purposes. - */ - tvm::Array type_constraints; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("arg_types", &arg_types); - v->Visit("ret_type", &ret_type); - v->Visit("type_params", &type_params); - v->Visit("type_constraints", &type_constraints); - v->Visit("span", &span); - } - - TVM_DLL static FuncType make(tvm::Array arg_types, - Type ret_type, - tvm::Array type_params, - tvm::Array type_constraints); - - static constexpr const char* _type_key = "relay.FuncType"; - TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode); -}; - -class FuncType : public Type { - public: - TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode); -}; - /*! * \brief The type of tuple values. */ diff --git a/src/ir/span.cc b/src/ir/span.cc new file mode 100644 index 000000000000..1d9f07951183 --- /dev/null +++ b/src/ir/span.cc @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file span.cc + * \brief The span data structure. + */ +#include +#include + +namespace tvm { + +ObjectPtr GetSourceNameNode(const std::string& name) { + // always return pointer as the reference can change as map re-allocate. + // or use another level of indirection by creating a unique_ptr + static std::unordered_map > source_map; + + auto sn = source_map.find(name); + if (sn == source_map.end()) { + ObjectPtr n = make_object(); + source_map[name] = n; + n->name = std::move(name); + return n; + } else { + return sn->second; + } +} + +SourceName SourceName::Get(const std::string& name) { + return SourceName(GetSourceNameNode(name)); +} + +TVM_REGISTER_GLOBAL("relay._make.SourceName") +.set_body_typed(SourceName::Get); + +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "SourceName(" << node->name << ", " << node << ")"; + }); + +TVM_REGISTER_NODE_TYPE(SourceNameNode) +.set_creator(GetSourceNameNode) +.set_global_key([](const Object* n) { + return static_cast(n)->name; + }); + +Span SpanNode::make(SourceName source, int lineno, int col_offset) { + auto n = make_object(); + n->source = std::move(source); + n->lineno = lineno; + n->col_offset = col_offset; + return Span(n); +} + +TVM_REGISTER_NODE_TYPE(SpanNode); + +TVM_REGISTER_GLOBAL("relay._make.Span") +.set_body_typed(SpanNode::make); + +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "Span(" << node->source << ", " << node->lineno << ", " + << node->col_offset << ")"; + }); +} // namespace tvm diff --git a/src/ir/type.cc b/src/ir/type.cc new file mode 100644 index 000000000000..ef5f75b86a2c --- /dev/null +++ b/src/ir/type.cc @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/tvm/ir/type.cc + * \brief Common type system AST nodes throughout the IR. + */ +#include +#include + +namespace tvm { + +TypeVar TypeVarNode::make(std::string name, TypeKind kind) { + ObjectPtr n = make_object(); + n->name_hint = std::move(name); + n->kind = std::move(kind); + return TypeVar(n); +} + +TVM_REGISTER_NODE_TYPE(TypeVarNode); + +TVM_REGISTER_GLOBAL("relay._make.TypeVar") +.set_body_typed([](std::string name, int kind) { + return TypeVarNode::make(name, static_cast(kind)); +}); + +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TypeVar(" << node->name_hint << ", " + << node->kind << ")"; +}); + +GlobalTypeVar GlobalTypeVarNode::make(std::string name, TypeKind kind) { + ObjectPtr n = make_object(); + n->name_hint = std::move(name); + n->kind = std::move(kind); + return GlobalTypeVar(n); +} + +TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode); + +TVM_REGISTER_GLOBAL("relay._make.GlobalTypeVar") +.set_body_typed([](std::string name, int kind) { + return GlobalTypeVarNode::make(name, static_cast(kind)); +}); + +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "GlobalTypeVar(" << node->name_hint << ", " + << node->kind << ")"; +}); + +FuncType FuncTypeNode::make(tvm::Array arg_types, + Type ret_type, + tvm::Array type_params, + tvm::Array type_constraints) { + ObjectPtr n = make_object(); + n->arg_types = std::move(arg_types); + n->ret_type = std::move(ret_type); + n->type_params = std::move(type_params); + n->type_constraints = std::move(type_constraints); + return FuncType(n); +} + +TVM_REGISTER_NODE_TYPE(FuncTypeNode); + +TVM_REGISTER_GLOBAL("relay._make.FuncType") +.set_body_typed(FuncTypeNode::make); + +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "FuncType(" << node->type_params << ", " + << node->arg_types << ", " << node->ret_type << ", " + << node->type_constraints << ")"; +}); + +} // namespace tvm diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc index ca8755730d80..3f98d878d2bf 100644 --- a/src/relay/ir/base.cc +++ b/src/relay/ir/base.cc @@ -22,76 +22,26 @@ * \brief The core base types for Relay. */ #include +#include #include namespace tvm { namespace relay { -using tvm::IRPrinter; using namespace tvm::runtime; -ObjectPtr GetSourceNameNode(const std::string& name) { - // always return pointer as the reference can change as map re-allocate. - // or use another level of indirection by creating a unique_ptr - static std::unordered_map > source_map; - - auto sn = source_map.find(name); - if (sn == source_map.end()) { - ObjectPtr n = make_object(); - source_map[name] = n; - n->name = std::move(name); - return n; - } else { - return sn->second; - } -} - -SourceName SourceName::Get(const std::string& name) { - return SourceName(GetSourceNameNode(name)); -} - -TVM_REGISTER_API("relay._make.SourceName") -.set_body_typed(SourceName::Get); - -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, tvm::IRPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "SourceName(" << node->name << ", " << node << ")"; - }); - -TVM_REGISTER_NODE_TYPE(SourceNameNode) -.set_creator(GetSourceNameNode) -.set_global_key([](const Object* n) { - return static_cast(n)->name; - }); - -Span SpanNode::make(SourceName source, int lineno, int col_offset) { - auto n = make_object(); - n->source = std::move(source); - n->lineno = lineno; - n->col_offset = col_offset; - return Span(n); -} - -TVM_REGISTER_NODE_TYPE(SpanNode); - -TVM_REGISTER_API("relay._make.Span") -.set_body_typed(SpanNode::make); - -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, tvm::IRPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "SpanNode(" << node->source << ", " << node->lineno << ", " - << node->col_offset << ")"; - }); - TVM_REGISTER_NODE_TYPE(IdNode); TVM_REGISTER_API("relay._base.set_span") .set_body_typed([](ObjectRef node_ref, Span sp) { - auto rn = node_ref.as(); + if (auto* rn = node_ref.as()) { CHECK(rn); rn->span = sp; + } else if (auto* rn = node_ref.as()) { + rn->span = sp; + } else { + LOG(FATAL) << "Expect Type or RelayNode "; + } }); } // namespace relay diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index 459e8b09c94e..6199c5490d76 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -228,11 +228,6 @@ class RelayHashHandler: hash = Combine(hash, TypeHash(var_node->type_annotation)); } hash_map_[var] = hash; - // TODO(tqchen) Introduce TypeVarExpr - // const auto* ty_param = var.as(); - // if (ty_param && ty_param->kind == Kind::kShapeVar) { - // hash_map_[ty_param->var] = hash; - // } return hash; } diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index 38f86a50df57..9f371dd86d01 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -234,7 +234,7 @@ Function ModuleNode::Lookup(const std::string& name) const { TypeData ModuleNode::LookupDef(const GlobalTypeVar& var) const { auto it = type_definitions.find(var); CHECK(it != type_definitions.end()) - << "There is no definition of " << var->name_hint; + << "There is no definition of " << var->name_hint; return (*it).second; } diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index 48f211b4006e..f1efddfbc89e 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -63,48 +63,6 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")"; }); -TypeVar TypeVarNode::make(std::string name, Kind kind) { - ObjectPtr n = make_object(); - n->name_hint = std::move(name); - n->kind = std::move(kind); - return TypeVar(n); -} - -TVM_REGISTER_NODE_TYPE(TypeVarNode); - -TVM_REGISTER_API("relay._make.TypeVar") -.set_body_typed([](std::string name, int kind) { - return TypeVarNode::make(name, static_cast(kind)); -}); - -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TypeVarNode(" << node->name_hint << ", " - << node->kind << ")"; -}); - -GlobalTypeVar GlobalTypeVarNode::make(std::string name, Kind kind) { - ObjectPtr n = make_object(); - n->name_hint = std::move(name); - n->kind = std::move(kind); - return GlobalTypeVar(n); -} - -TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode); - -TVM_REGISTER_API("relay._make.GlobalTypeVar") -.set_body_typed([](std::string name, int kind) { - return GlobalTypeVarNode::make(name, static_cast(kind)); - }); - -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "GlobalTypeVarNode(" << node->name_hint << ", " - << node->kind << ")"; -}); - TypeCall TypeCallNode::make(Type func, tvm::Array args) { ObjectPtr n = make_object(); n->func = std::move(func); @@ -143,31 +101,6 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")"; }); -FuncType FuncTypeNode::make(tvm::Array arg_types, - Type ret_type, - tvm::Array type_params, - tvm::Array type_constraints) { - ObjectPtr n = make_object(); - n->arg_types = std::move(arg_types); - n->ret_type = std::move(ret_type); - n->type_params = std::move(type_params); - n->type_constraints = std::move(type_constraints); - return FuncType(n); -} - -TVM_REGISTER_NODE_TYPE(FuncTypeNode); - -TVM_REGISTER_API("relay._make.FuncType") -.set_body_typed(FuncTypeNode::make); - -TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, IRPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "FuncTypeNode(" << node->type_params << ", " - << node->arg_types << ", " << node->ret_type << ", " - << node->type_constraints << ")"; -}); - TypeRelation TypeRelationNode::make(TypeRelationFn func, Array args, int num_inputs, diff --git a/tests/cpp/relay_pass_type_infer_test.cc b/tests/cpp/relay_pass_type_infer_test.cc index cdc69964562d..03ad228ee588 100644 --- a/tests/cpp/relay_pass_type_infer_test.cc +++ b/tests/cpp/relay_pass_type_infer_test.cc @@ -38,7 +38,7 @@ TEST(Relay, SelfReference) { auto type_fx = mod->Lookup("main"); auto expected = relay::FuncTypeNode::make(tvm::Array{ tensor_type }, tensor_type, {}, {}); - CHECK(AlphaEqual(type_fx->checked_type(), expected)); + CHECK(relay::AlphaEqual(type_fx->checked_type(), expected)); } int main(int argc, char ** argv) {