From 68e15a8d6515e4acbb1faf40e62640769e93bff4 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 12 Aug 2021 14:32:02 -0700 Subject: [PATCH] Relax AST (#2) Co-authored-by: ZihengJiang --- 3rdparty/cutlass | 2 +- CMakeLists.txt | 1 + include/tvm/ir/expr.h | 18 + include/tvm/node/structural_hash.h | 4 +- include/tvm/relax/expr.h | 404 ++++++++++++++++++ include/tvm/relax/type.h | 117 +++++ python/tvm/ir/module.py | 3 +- python/tvm/relax/__init__.py | 53 ++- python/tvm/relax/_ffi_api.py | 17 + .../tvm/relax/{builder.py => exec_builder.py} | 0 python/tvm/relax/expr.py | 118 +++++ python/tvm/relax/ty.py | 47 ++ python/tvm/relay/base.py | 4 +- src/relax/expr.cc | 181 ++++++++ src/relax/type.cc | 58 +++ src/relay/ir/base.cc | 2 + tests/python/relax/test_ast.py | 125 ++++++ tests/python/relax/test_type.py | 35 ++ .../relax/{test_relax_vm.py => test_vm.py} | 0 19 files changed, 1181 insertions(+), 8 deletions(-) create mode 100644 include/tvm/relax/expr.h create mode 100644 include/tvm/relax/type.h rename python/tvm/relax/{builder.py => exec_builder.py} (100%) create mode 100644 python/tvm/relax/expr.py create mode 100644 python/tvm/relax/ty.py create mode 100644 src/relax/expr.cc create mode 100644 src/relax/type.cc create mode 100644 tests/python/relax/test_ast.py create mode 100644 tests/python/relax/test_type.py rename tests/python/relax/{test_relax_vm.py => test_vm.py} (100%) diff --git a/3rdparty/cutlass b/3rdparty/cutlass index c2ee13a0fe99..a3bcc6981d5d 160000 --- a/3rdparty/cutlass +++ b/3rdparty/cutlass @@ -1 +1 @@ -Subproject commit c2ee13a0fe99241b0e798ce647acf98e237f1d0c +Subproject commit a3bcc6981d5dad3afb212689e2c7853d1b1ee45d diff --git a/CMakeLists.txt b/CMakeLists.txt index 6f1ded3bed9a..58d8dd7e3336 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -322,6 +322,7 @@ tvm_file_glob(GLOB_RECURSE RELAY_IR_SRCS tvm_file_glob(GLOB_RECURSE RELAY_QNN_SRCS src/relay/qnn/*.cc ) + list(APPEND COMPILER_SRCS ${RELAY_OP_SRCS}) list(APPEND COMPILER_SRCS ${RELAY_PASS_SRCS}) list(APPEND COMPILER_SRCS ${RELAY_BACKEND_SRCS}) diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index c8531c88465a..b145fa8b9532 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -135,6 +135,7 @@ class PrimExpr : public BaseExpr { TVM_DLL static PrimExpr FromObject_(ObjectRef ref); }; +class RelayExpr; /*! * \brief add operator * @@ -367,10 +368,27 @@ class RelayExprNode : public BaseExprNode { * This value is discarded during serialization. */ mutable Type checked_type_ = Type(nullptr); + + /*! + * \brief Stores the result of static shape analysis. + * + * \note The value will be optional if a static shape can not be inferred. + * use .shape() instead to acesss an always defined shape expression. + */ + Optional> shape_ = Optional>(); + /*! * \return The checked_type */ inline const Type& checked_type() const; + + /*! + * \return An expression which corresponds to the shape of the expression. + * + * Only valid when the expression's type is a Tensor. + */ + inline RelayExpr shape() const; + /*! * \brief Check if the inferred(checked) type of the Expr * is backed by a TTypeNode and return it. diff --git a/include/tvm/node/structural_hash.h b/include/tvm/node/structural_hash.h index 8b8a403326c4..638d6865ce37 100644 --- a/include/tvm/node/structural_hash.h +++ b/include/tvm/node/structural_hash.h @@ -17,7 +17,7 @@ * under the License. */ /*! - * \file tvm/node/structural_equal.h + * \file tvm/node/structural_hash.h * \brief Structural hash class. */ #ifndef TVM_NODE_STRUCTURAL_HASH_H_ @@ -174,7 +174,7 @@ class SHashReducer { /*! * \brief Push hash of key to the current sequence of hash values. * \param key The key to be hashed. - * \note This function indicate key could contain var defintions. + * \note This function indicates key could contain variable defintions. */ void DefHash(const ObjectRef& key) const { return handler_->SHashReduce(key, true); } /*! diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h new file mode 100644 index 000000000000..5d973e23a70b --- /dev/null +++ b/include/tvm/relax/expr.h @@ -0,0 +1,404 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_TVM_RELAX_EXPR_H_ +#define TVM_TVM_RELAX_EXPR_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +using relay::Id; +using ExprNode = RelayExprNode; +using Expr = RelayExpr; + +/*! \brief A shape expression which allows users to construct a shape containing PrimExpr. + */ +class ShapeExprNode : public ExprNode { + public: + /*! The values of the shape expression. */ + Array values; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("values", &values); + v->Visit("shape_", &shape_); + v->Visit("checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const ShapeExprNode* other, SEqualReducer equal) const { + return equal(values, other->values) && + equal(checked_type_, other->checked_type_) && + equal(shape_, other->shape_); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(values); + hash_reduce(checked_type_); + hash_reduce(shape_); + } + + static constexpr const char* _type_key = "relax.expr.ShapeExpr"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(ShapeExprNode, ExprNode); +}; + +class ShapeExpr : public Expr { + public: + TVM_DLL ShapeExpr(Array values); + TVM_DEFINE_OBJECT_REF_METHODS(ShapeExpr, Expr, ShapeExprNode); +}; + + +/*! \brief The variable class for all Relax bindings. */ +class VarNode : public ExprNode { + public: + /*! \brief The identifier of the variable, is used for comparing stable equality across transformations. */ + Id vid; + /*! \brief The type annotation, used by binding sites and parameter declarations. */ + runtime::Optional type_annotation; + + /*! \return The name hint of the variable */ + const String& name_hint() const { return vid->name_hint; } + + void VisitAttrs(AttrVisitor* v) { + v->Visit("vid", &vid); + v->Visit("type_annotation", &type_annotation); + v->Visit("span", &span); + v->Visit("shape_", &shape_); + v->Visit("checked_type_", &checked_type_); + } + + bool SEqualReduce(const VarNode* other, SEqualReducer equal) const { + return equal(vid, other->vid) && + equal(type_annotation, other->type_annotation) && + // Do we use the analysis information in equality? + equal(checked_type_, other->checked_type_) && + equal(shape_, other->shape_); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(vid); + hash_reduce(type_annotation); + hash_reduce(shape_); + hash_reduce(checked_type_); + } + + static constexpr const char* _type_key = "relax.expr.Var"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + static constexpr const uint32_t _type_child_slots = 2; + TVM_DECLARE_BASE_OBJECT_INFO(VarNode, ExprNode); +}; + +class Var : public Expr { + public: + TVM_DLL Var(String name_hint, + runtime::Optional> shape_annotation, + runtime::Optional type_annotation, + Span span = Span()) + : Var(Id(name_hint), shape_annotation, type_annotation, span) {} + + TVM_DLL Var(Id vid, + runtime::Optional> shape_annotation, + runtime::Optional type_annotation, + Span span = Span()); + TVM_DEFINE_OBJECT_REF_METHODS(Var, Expr, VarNode); +}; + +/*! \brief A sub-type of the variable node used to mark dataflow variables from + * normal visible "function local" bindings. + */ +class DataflowVarNode : public VarNode { + public: + void VisitAttrs(AttrVisitor* v) { + v->Visit("vid", &vid); + v->Visit("type_annotation", &type_annotation); + v->Visit("span", &span); + v->Visit("shape_", &shape_); + v->Visit("checked_type_", &checked_type_); + } + + bool SEqualReduce(const DataflowVarNode* other, SEqualReducer equal) const { + return equal(vid, other->vid) && + equal(type_annotation, other->type_annotation) && + equal(shape_, other->shape_) && + equal(checked_type_, other->checked_type_); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(vid); + hash_reduce(type_annotation); + hash_reduce(shape_); + hash_reduce(checked_type_); + } + + static constexpr const char* _type_key = "relax.expr.DataflowVar"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(DataflowVarNode, VarNode); +}; + +class DataflowVar : public Var { + public: + using Var::Var; // inherit constructors from Var + TVM_DEFINE_OBJECT_REF_METHODS(DataflowVar, Var, DataflowVarNode); +}; + + +/*! \brief The base class of a variable binding in Relax. */ +class BindingNode : public Object { + public: + void VisitAttrs(AttrVisitor* v) {} + bool SEqualReduce(const BindingNode* other, SEqualReducer equal) const { return true; } + void SHashReduce(SHashReducer hash_reduce) const {} + + static constexpr const char* _type_key = "relax.expr.Binding"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_BASE_OBJECT_INFO(BindingNode, Object); +}; + +class Binding : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Binding, ObjectRef, BindingNode); +}; + + +/*! \brief Symbolic shape match, binds the variables of the LHS with the rhs. */ +class MatchShape; +class MatchShapeNode : public BindingNode { + public: + Array pattern; + Expr value; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("pattern", &pattern); + v->Visit("value", &value); + } + + bool SEqualReduce(const MatchShapeNode* other, SEqualReducer equal) const { + return equal(pattern, other->pattern) && equal(value, other->value); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(pattern); + hash_reduce(value); + } + + static constexpr const char* _type_key = "relax.expr.MatchShape"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(MatchShapeNode, BindingNode); +}; + +class MatchShape : public Binding { + public: + TVM_DLL MatchShape(Array pattern, Expr value); + TVM_DEFINE_OBJECT_REF_METHODS(MatchShape, Binding, MatchShapeNode); +}; + +class VarBinding; +class VarBindingNode : public BindingNode { + public: + Var var; + Expr value; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("var", &var); + v->Visit("value", &value); + } + + bool SEqualReduce(const VarBindingNode* other, SEqualReducer equal) const { + return equal(var, other->var) && equal(value, other->value); + } + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(var); + hash_reduce(value); + } + static constexpr const char* _type_key = "relax.expr.VarBinding"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(VarBindingNode, BindingNode); +}; + +class VarBinding : public Binding { + public: + TVM_DLL VarBinding(Var var, Expr value); + TVM_DEFINE_OBJECT_REF_METHODS(VarBinding, Binding, VarBindingNode); +}; + + +class BindingBlock; + +class BindingBlockNode : public Object { + public: + Array bindings; + void VisitAttrs(AttrVisitor* v) { + v->Visit("bindings", &bindings); + } + bool SEqualReduce(const BindingBlockNode* other, SEqualReducer equal) const { + return equal(bindings, other->bindings); + } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(bindings); } + static constexpr const char* _type_key = "relax.expr.BindingBlock"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_BASE_OBJECT_INFO(BindingBlockNode, Object); +}; + +class BindingBlock : public ObjectRef { + public: + TVM_DLL BindingBlock(Array bindings); + TVM_DEFINE_OBJECT_REF_METHODS(BindingBlock, ObjectRef, BindingBlockNode); +}; + + +class DataflowBlock; +class DataflowBlockNode : public BindingBlockNode { + public: + void VisitAttrs(AttrVisitor* v) { + v->Visit("bindings", &bindings); + } + bool SEqualReduce(const DataflowBlockNode* other, SEqualReducer equal) const { + return equal(bindings, other->bindings); + } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(bindings); } + static constexpr const char* _type_key = "relax.expr.DataflowBlock"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(DataflowBlockNode, BindingBlockNode); +}; + +class DataflowBlock : public BindingBlock { + public: + TVM_DLL DataflowBlock(Array bindings); + TVM_DEFINE_OBJECT_REF_METHODS(DataflowBlock, BindingBlock, DataflowBlockNode); +}; + +/*! \brief A sequence of blocks followed by an expression. + * + * The order of blocks enforces scoping and ordering. + */ +class SeqExprNode : public ExprNode { + public: + Array blocks; + Expr body; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("blocks", &blocks); + v->Visit("body", &body); + v->Visit("shape_", &shape_); + v->Visit("checked_type_", &checked_type_); + v->Visit("span", &span); + } + + bool SEqualReduce(const SeqExprNode* other, SEqualReducer equal) const { + return equal(blocks, other->blocks) && equal(body, other->body) && + equal(checked_type_, other->checked_type_) && equal(shape_, other->shape_); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(blocks); + hash_reduce(body); + hash_reduce(shape_); + hash_reduce(checked_type_); + } + + static constexpr const char* _type_key = "relax.expr.SeqExpr"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(SeqExprNode, ExprNode); +}; + +class SeqExpr : public Expr { + public: + TVM_DLL SeqExpr(Array blocks, Expr body); + TVM_DEFINE_OBJECT_REF_METHODS(SeqExpr, Expr, SeqExprNode); +}; + + +/*! \brief A Relax function, eventually to replace the current Relay function definition. */ +class FunctionNode : public BaseFuncNode { + public: + /*! + * \brief Optionally attach the function's name for improved printing, and debugging. + * It need to be consistent with the GlobalVar in the IRModule. + */ + runtime::Optional name; + /*! \brief The parameters to the function. */ + Array params; + /*! \brief The body of the function. */ + Expr body; + /*! \brief The return type of the function. */ + Type ret_type; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("params", ¶ms); + v->Visit("body", &body); + v->Visit("ret_type", &ret_type); + v->Visit("checked_type_", &checked_type_); + v->Visit("shape_", &shape_); + v->Visit("span", &span); + } + + bool SEqualReduce(const FunctionNode* other, SEqualReducer equal) const { + equal->MarkGraphNode(); + return equal.DefEqual(params, other->params) && + equal(body, other->body) && + equal(ret_type, other->ret_type) && equal(checked_type_, other->checked_type_) && + equal(shape_, other->shape_); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce->MarkGraphNode(); + hash_reduce(name); + hash_reduce(params); + hash_reduce(body); + hash_reduce(ret_type); + hash_reduce(checked_type_); + hash_reduce(shape_); + } + + static constexpr const char* _type_key = "relax.expr.Function"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, ExprNode); +}; + +class Function : public Expr { + public: + TVM_DLL Function(runtime::Optional name, Array params, + Expr body, Type ret_type); + TVM_DEFINE_OBJECT_REF_METHODS(Function, Expr, FunctionNode); +}; + +} // namespace relax +} // namespace tvm + +#endif // TVM_TVM_RELAX_EXPR_H_ diff --git a/include/tvm/relax/type.h b/include/tvm/relax/type.h new file mode 100644 index 000000000000..f073a7522855 --- /dev/null +++ b/include/tvm/relax/type.h @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/type.h + * \brief Relax typed AST nodes. + */ +#ifndef TVM_RELAX_TYPE_H_ +#define TVM_RELAX_TYPE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace tvm { +namespace relax { + +class ShapeTypeNode : public TypeNode { + public: + + void VisitAttrs(tvm::AttrVisitor* v) { + } + + bool SEqualReduce(const ShapeTypeNode* other, SEqualReducer equal) const { + return true; + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); } + + static constexpr const char* _type_key = "relax.ShapeType"; + TVM_DECLARE_FINAL_OBJECT_INFO(ShapeTypeNode, TypeNode); +}; + +class ShapeType : public Type { + public: + explicit ShapeType(); + explicit ShapeType(runtime::ObjectPtr n) : Type(n) {} + TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(ShapeType); + const ShapeTypeNode* operator->() const { + return static_cast(data_.get()); + } + const ShapeTypeNode* get() const { + return operator->(); + } + using ContainerType = ShapeTypeNode; +}; + + +class DynTensorTypeNode : public BaseTensorTypeNode { + public: + /*! + * \brief The rank of the tensor, use -1 to denote dynamic rank tensor. + */ + int rank; + /*! \brief The content data type, use void to denote the dtype is unknown. */ + DataType dtype; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("rank", &rank); + v->Visit("dtype", &dtype); + v->Visit("span", &span); + } + + bool SEqualReduce(const DynTensorTypeNode* other, SEqualReducer equal) const { + return equal(rank, other->rank) && equal(dtype, other->dtype); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(rank); + hash_reduce(dtype); + } + + static constexpr const char* _type_key = "relax.DynTensorType"; + TVM_DECLARE_FINAL_OBJECT_INFO(DynTensorTypeNode, BaseTensorTypeNode); +}; + +/*! + * \brief Managed reference to DynTensorTypeNode. + * \sa DynTensorTypeNode. + */ +class DynTensorType : public Type { + public: + /*! + * \brief Constructor. + * \param shape The shape of the tensor. + * \param dtype The runtime dtype of the tensor's elements. + */ + TVM_DLL DynTensorType(int rank, DataType dtype); + + TVM_DEFINE_OBJECT_REF_METHODS(DynTensorType, Type, DynTensorTypeNode); +}; + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_TYPE_H_ diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index 3daffb2640c5..76d5bcdc2042 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -21,6 +21,7 @@ from . import _ffi_api from . import expr as _expr +from ..ir.function import BaseFunc from . import type as _ty from .base import Node @@ -76,7 +77,7 @@ def __setitem__(self, var, val): return self._add(var, val, True) def _add(self, var, val, update=True): - if isinstance(val, _expr.RelayExpr): + if isinstance(val, (_expr.RelayExpr, BaseFunc)): if isinstance(var, string_types): if _ffi_api.Module_ContainGlobalVar(self, var): var = _ffi_api.Module_GetGlobalVar(self, var) diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index 497937327991..ed324cd60e4e 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -1,2 +1,51 @@ -from .vm import VirtualMachine, load_exec_from_file -from .builder import ExecBuilder +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from . import exec_builder +from . import expr +from . import ty +from . import vm + + +# Expr +Expr = expr.Expr +Span = expr.Span +SourceName = expr.SourceName +Id = expr.Id +GlobalVar = expr.GlobalVar +Var = expr.Var +DataflowVar = expr.DataflowVar +Binding = expr.Binding +MatchShape = expr.MatchShape +VarBinding = expr.VarBinding +BindingBlock = expr.BindingBlock +DataflowBlock = expr.DataflowBlock +SeqExpr = expr.SeqExpr +ShapeExpr = expr.ShapeExpr +Function = expr.Function + +# helper functions +const = expr.const + +# Type +ShapeType = ty.ShapeType +DynTensorType = ty.DynTensorType + +# VM +ExecBuilder = exec_builder.ExecBuilder +VirtualMachine = vm.VirtualMachine +load_exec_from_file = vm.load_exec_from_file diff --git a/python/tvm/relax/_ffi_api.py b/python/tvm/relax/_ffi_api.py index 62b8f8a2e5a8..a127e1c81378 100644 --- a/python/tvm/relax/_ffi_api.py +++ b/python/tvm/relax/_ffi_api.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI API for Relax.""" import tvm._ffi tvm._ffi._init_api("relax", __name__) diff --git a/python/tvm/relax/builder.py b/python/tvm/relax/exec_builder.py similarity index 100% rename from python/tvm/relax/builder.py rename to python/tvm/relax/exec_builder.py diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py new file mode 100644 index 000000000000..b68aa34a869a --- /dev/null +++ b/python/tvm/relax/expr.py @@ -0,0 +1,118 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import List, Optional, Union, Dict +import tvm._ffi +from ..ir.base import Node, Span, SourceName +from ..relay.base import Id +from ..tir import PrimExpr +from . import _ffi_api +from .. import relay + +GlobalVar = relay.GlobalVar +Expr = relay.Expr +Type = relay.Type +const = relay.const + + +@tvm._ffi.register_object("relax.expr.ShapeExpr") +class ShapeExpr(Expr): + values: List[PrimExpr] + + def __init__(self, values: List[PrimExpr]) -> None: + self.__init_handle_by_constructor__(_ffi_api.ShapeExpr, values) + + +@tvm._ffi.register_object("relax.expr.Var") +class Var(Expr): + id: Id + type_annotation: Optional[Type] + + def __init__(self, name_hint: str, + shape_annotation: Optional[List[Type]] = None, + type_annotation: Optional[Type] = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.Var, name_hint, + shape_annotation, + type_annotation) + + @property + def name_hint(self): + """Get name hint of the current var.""" + name = str(self.vid.name_hint) + return name + + +@tvm._ffi.register_object("relax.expr.DataflowVar") +class DataflowVar(Var): + pass + + +@tvm._ffi.register_object("relax.expr.Binding") +class Binding(Node): + pass + + +@tvm._ffi.register_object("relax.expr.MatchShape") +class MatchShape(Binding): + pattern: List[PrimExpr] + value: Expr + + def __init__(self, pattern: List[PrimExpr], value: Expr) -> None: + self.__init_handle_by_constructor__(_ffi_api.MatchShape, pattern, value) + + +@tvm._ffi.register_object("relax.expr.VarBinding") +class VarBinding(Binding): + var: Var + value: Expr + + def __init__(self, var: Var, value: Expr) -> None: + self.__init_handle_by_constructor__(_ffi_api.VarBinding, var, value) + + +@tvm._ffi.register_object("relax.expr.BindingBlock") +class BindingBlock(Node): + bindings: List[Binding] + + def __init__(self, bindings: List[Binding]) -> None: + self.__init_handle_by_constructor__(_ffi_api.BindingBlock, bindings) + + +@tvm._ffi.register_object("relax.expr.DataflowBlock") +class DataflowBlock(BindingBlock): + pass + + +@tvm._ffi.register_object("relax.expr.SeqExpr") +class SeqExpr(Expr): + blocks: List[BindingBlock] + body: Expr + + def __init__(self, blocks: List[BindingBlock], body: Expr) -> None: + self.__init_handle_by_constructor__(_ffi_api.SeqExpr, blocks, body) + + +@tvm._ffi.register_object("relax.expr.Function") +class Function(Expr): + name: Optional[GlobalVar] + params: List[Var] + body: Expr + ret_type: Type + + def __init__(self, params: List[Var], body: Expr, + ret_type: Type, name: Optional[GlobalVar] = None) -> None: + self.__init_handle_by_constructor__(_ffi_api.Function, name, params, + body, ret_type) diff --git a/python/tvm/relax/ty.py b/python/tvm/relax/ty.py new file mode 100644 index 000000000000..0c34d2797d9b --- /dev/null +++ b/python/tvm/relax/ty.py @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-import +"""The type nodes of the Relax language.""" +import tvm._ffi +from tvm.ir import Type, TensorType + +from . import _ffi_api + + +@tvm._ffi.register_object("relax.ShapeType") +class ShapeType(Type): + def __init__(self): + self.__init_handle_by_constructor__(_ffi_api.ShapeType) + + +@tvm._ffi.register_object("relax.DynTensorType") +class DynTensorType(TensorType): + """A dynamic TensorType in Relax. + + This is the type assigned to tensors with a known dtype and unknown shape. + + Parameters + ---------- + rank : Optional[int] + The rank of the Tensor + + dtype : Optional[str] + The content data type. + """ + + def __init__(self, rank=-1, dtype="float32"): + self.__init_handle_by_constructor__(_ffi_api.DynTensorType, rank, dtype) diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py index 8667bfb1dfdc..5f436a263d64 100644 --- a/python/tvm/relay/base.py +++ b/python/tvm/relay/base.py @@ -72,5 +72,5 @@ class Id(Object): Guaranteed to be stable across all passes. """ - def __init__(self): - raise RuntimeError("Cannot directly construct Id") + def __init__(self, string): + self.__init_handle_by_constructor__(_ffi_api.Id, string) diff --git a/src/relax/expr.cc b/src/relax/expr.cc new file mode 100644 index 000000000000..1b5901fe9a4a --- /dev/null +++ b/src/relax/expr.cc @@ -0,0 +1,181 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +*/ + +#include + +namespace tvm { +namespace relax { + +using tvm::runtime::Optional; + + +TVM_REGISTER_NODE_TYPE(ShapeExprNode); + +ShapeExpr::ShapeExpr(Array values) { + ObjectPtr n = make_object(); + n->values = std::move(values); + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.ShapeExpr") +.set_body_typed([](Array values) { + return ShapeExpr(values); +}); + + +TVM_REGISTER_NODE_TYPE(VarNode); + +Var::Var(Id vid, + Optional> shape_annotation, + Optional type_annotation, + Span span) { + ObjectPtr n = make_object(); + n->vid = std::move(vid); + n->shape_ = std::move(shape_annotation); + n->type_annotation = std::move(type_annotation); + n->span = std::move(span); + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.Var") +.set_body_typed([](String name_hint, + Optional> shape_annotation, + Optional type_annotation) { + return Var(name_hint, shape_annotation, type_annotation); +}); + + +TVM_REGISTER_NODE_TYPE(DataflowVarNode); + +TVM_REGISTER_GLOBAL("relax.DataflowVar") +.set_body_typed([](String name_hint, + Optional> shape_annotation, + Optional type_annotation) { + return DataflowVar(name_hint, shape_annotation, type_annotation); +}); + + +TVM_REGISTER_NODE_TYPE(BindingNode); + +TVM_REGISTER_GLOBAL("relax.Binding") +.set_body_typed([]() { + return Binding(); +}); + + +TVM_REGISTER_NODE_TYPE(MatchShapeNode); + +MatchShape::MatchShape(Array pattern, + Expr value) { + ObjectPtr n = make_object(); + n->pattern = std::move(pattern); + n->value = std::move(value); + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.MatchShape") +.set_body_typed([](Array pattern, Expr value) { + return MatchShape(pattern, value); +}); + + +TVM_REGISTER_NODE_TYPE(VarBindingNode); + +VarBinding::VarBinding(Var var, + Expr value) { + ObjectPtr n = make_object(); + n->var = std::move(var); + n->value = std::move(value); + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.VarBinding") +.set_body_typed([](Var var,Expr value) { + return VarBinding(var,value); +}); + + +TVM_REGISTER_NODE_TYPE(BindingBlockNode); + +BindingBlock::BindingBlock(Array bindings) { + ObjectPtr n = make_object(); + n->bindings = std::move(bindings); + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.BindingBlock") +.set_body_typed([](Array bindings) { + return BindingBlock(bindings); +}); + + +TVM_REGISTER_NODE_TYPE(DataflowBlockNode); + +DataflowBlock::DataflowBlock(Array bindings) { + ObjectPtr n = make_object(); + n->bindings = std::move(bindings); + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.DataflowBlock") +.set_body_typed([](Array bindings) { + return DataflowBlock(bindings); +}); + + +TVM_REGISTER_NODE_TYPE(SeqExprNode); + +SeqExpr::SeqExpr(Array blocks, + Expr body) { + ObjectPtr n = make_object(); + n->blocks = std::move(blocks); + n->body = std::move(body); + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.SeqExpr") +.set_body_typed([](Array blocks, Expr body) { + return SeqExpr(blocks, body); +}); + + +Function::Function(runtime::Optional name, + Array params, + Expr body, + Type ret_type) { + ObjectPtr n = make_object(); + n->name = std::move(name); + n->params = std::move(params); + n->body = std::move(body); + n->ret_type = std::move(ret_type); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(FunctionNode); + +TVM_REGISTER_GLOBAL("relax.Function") +.set_body_typed([](runtime::Optional name, + Array params, + Expr body, + Type ret_type) { + return Function(name, params, body, ret_type); +}); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/type.cc b/src/relax/type.cc new file mode 100644 index 000000000000..498d6082de78 --- /dev/null +++ b/src/relax/type.cc @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/type.cc + * \brief Relax's type system AST nodes throughout the IR. + */ +#include +#include + +namespace tvm { +namespace relax { + +TVM_REGISTER_NODE_TYPE(ShapeTypeNode); + +ShapeType::ShapeType() { + ObjectPtr n = make_object(); + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.ShapeType") +.set_body_typed([]() { + return ShapeType(); +}); + +DynTensorType::DynTensorType(int rank, DataType dtype) { + ObjectPtr n = make_object(); + n->rank = std::move(rank); + n->dtype = std::move(dtype); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(DynTensorTypeNode); + +TVM_REGISTER_GLOBAL("relax.DynTensorType") +.set_body_typed([](int rank, DataType dtype) { + return DynTensorType(rank, dtype); +}); + + +} // namespace relax +} // namespace tvm diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc index deedd283c2ff..f9f3b31abbdd 100644 --- a/src/relay/ir/base.cc +++ b/src/relay/ir/base.cc @@ -39,5 +39,7 @@ Id::Id(String name_hint) { data_ = std::move(n); } +TVM_REGISTER_GLOBAL("relay.ir.Id").set_body_typed([](String name_hint) { return Id(name_hint); }); + } // namespace relay } // namespace tvm diff --git a/tests/python/relax/test_ast.py b/tests/python/relax/test_ast.py new file mode 100644 index 000000000000..2e55909f7531 --- /dev/null +++ b/tests/python/relax/test_ast.py @@ -0,0 +1,125 @@ +import tvm +from tvm import tir +from tvm import relax as rx +from tvm.ir import TensorType +import numpy as np + + +def test_var() -> None: + v0 = rx.Var("v0") + assert v0.name_hint == "v0" + assert v0.shape_ is None + assert v0.type_annotation is None + shape_anno = [54, 96] + type_anno = TensorType(shape_anno, "float32") + v1 = rx.Var("v1", shape_anno, type_anno) + assert v1.name_hint == "v1" + for s0, s1 in zip(v1.shape_, shape_anno): + assert s0 == s1 + assert v1.type_annotation == type_anno + + +def test_dataflow_var() -> None: + v0 = rx.DataflowVar("v0") + assert v0.name_hint == "v0" + assert v0.shape_ is None + assert v0.type_annotation is None + shape_anno = [54, 96] + type_anno = TensorType(shape_anno, "float16") + v1 = rx.DataflowVar("v1", shape_anno, type_anno) + assert v1.name_hint == "v1" + for s0, s1 in zip(v1.shape_, shape_anno): + assert s0 == s1 + assert v1.type_annotation == type_anno + assert isinstance(v1, rx.DataflowVar) + + +def test_match_shape() -> None: + m = tir.Var("m", dtype="int32") + n = tir.Var("n", dtype="int32") + shape = rx.const([16, 8], "int32") + b0 = rx.MatchShape([m, n], shape) + assert b0.pattern[0] == m + assert b0.pattern[1] == n + assert b0.value == shape + + +def test_var_binding() -> None: + v0 = rx.Var("v0") + val = rx.const(np.random.rand(24, 56)) + b0 = rx.VarBinding(v0, val) + assert b0.var.name_hint == "v0" + assert b0.value == val + + +def test_binding_block() -> None: + m = tir.Var("m", dtype="int32") + n = tir.Var("n", dtype="int32") + shape = rx.const([16, 8], "int32") + b0 = rx.MatchShape([m, n], shape) + + v0 = rx.Var("v0") + val = rx.const(np.random.rand(24, 56)) + b1 = rx.VarBinding(v0, val) + + block0 = rx.BindingBlock([b0, b1]) + assert block0.bindings[0] == b0 + assert block0.bindings[1] == b1 + + +def test_dataflow_block() -> None: + m = tir.Var("m", dtype="int32") + n = tir.Var("n", dtype="int32") + shape = rx.const([16, 8], "int32") + b0 = rx.MatchShape([m, n], shape) + + v0 = rx.Var("v0") + val = rx.const(np.random.rand(24, 56)) + b1 = rx.VarBinding(v0, val) + + block0 = rx.DataflowBlock([b0, b1]) + assert block0.bindings[0] == b0 + assert block0.bindings[1] == b1 + assert isinstance(block0, rx.DataflowBlock) + + +def test_seq_expr() -> None: + x = rx.Var("foo") + bindings = [rx.VarBinding(x, rx.const(1))] + blocks = [rx.BindingBlock(bindings)] + seqe = rx.SeqExpr(blocks, x) + assert seqe.blocks[0] == blocks[0] + assert seqe.body == x + + +def test_shape_expr() -> None: + m = tir.Var("m", dtype="int32") + n = tir.Var("n", dtype="int32") + s = rx.ShapeExpr([m, n]) + assert s.values[0] == m + assert s.values[1] == n + + +def test_func(): + x = rx.Var("foo") + bindings = [rx.VarBinding(x, rx.const(1))] + blocks = [rx.BindingBlock(bindings)] + seqe = rx.SeqExpr(blocks, x) + ret_type = TensorType(None, "float32") + func = rx.Function([x], seqe, ret_type, rx.GlobalVar("func")) + assert func.params[0] == x + assert func.body == seqe + assert func.ret_type == ret_type + assert func.name.name_hint == "func" + + +if __name__ == "__main__": + test_var() + test_dataflow_var() + test_match_shape() + test_var_binding() + test_binding_block() + test_dataflow_block() + test_seq_expr() + test_shape_expr() + test_func() diff --git a/tests/python/relax/test_type.py b/tests/python/relax/test_type.py new file mode 100644 index 000000000000..6d8174aaad5b --- /dev/null +++ b/tests/python/relax/test_type.py @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +import tvm +from tvm import relax as rx + +def test_shape_type(): + t0 = rx.ShapeType() + t1 = rx.ShapeType() + assert t0 == t1 + +def test_dyn_tensor_type(): + t0 = rx.DynTensorType() + assert t0.rank == -1 + t1 = rx.DynTensorType(3, "int32") + assert t1.rank == 3 + assert t1.dtype == "int32" + +if __name__ == "__main__": + test_shape_type() + test_dyn_tensor_type() diff --git a/tests/python/relax/test_relax_vm.py b/tests/python/relax/test_vm.py similarity index 100% rename from tests/python/relax/test_relax_vm.py rename to tests/python/relax/test_vm.py