diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h
index 1fc28c817cd4d..899db08055fc3 100644
--- a/include/tvm/ir/attrs.h
+++ b/include/tvm/ir/attrs.h
@@ -277,26 +277,13 @@ class BaseAttrsNode : public Object {
   TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object);
 };
 
-/*! \brief Base attribute container for all attributes */
+/*!
+ * \brief Managed reference to BaseAttrsNode.
+ * \sa AttrsNode, BaseAttrsNode
+ */
 class Attrs : public ObjectRef {
  public:
-  // normal constructor
-  Attrs() {}
-  // construct from shared ptr.
-  explicit Attrs(ObjectPtr<Object> n) : ObjectRef(n) {}
-
-  /*! \return The attribute node */
-  const BaseAttrsNode* operator->() const {
-    return ptr();
-  }
-  /*! \brief specify container node */
-  using ContainerType = BaseAttrsNode;
-
- private:
-  /*! \return the internal attribute node */
-  const BaseAttrsNode* ptr() const {
-    return static_cast<const BaseAttrsNode*>(get());
-  }
+  TVM_DEFINE_OBJECT_REF_METHODS(Attrs, ObjectRef, BaseAttrsNode);
 };
 
 /*!
@@ -309,12 +296,7 @@ class DictAttrsNode : public BaseAttrsNode {
  public:
   /*! \brief internal attrs map */
   Map<std::string, ObjectRef> dict;
-  /*!
-   * \brief Consruct a Attrs backed by DictAttrsNode.
-   * \param dict The attributes.
-   * \return The dict attributes.
-   */
-  TVM_DLL static Attrs make(Map<std::string, ObjectRef> dict);
+  // implementations
 
   void VisitAttrs(AttrVisitor* v) final;
   void VisitNonDefaultAttrs(AttrVisitor* v) final;
@@ -327,6 +309,23 @@ class DictAttrsNode : public BaseAttrsNode {
   TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode);
 };
 
+/*!
+ * \brief Managed reference to DictAttrsNode.
+ * \sa DictAttrsNode
+ */
+class DictAttrs : public Attrs {
+ public:
+  /*!
+   * \brief Construct an Attrs backed by DictAttrsNode.
+   * \param dict The attributes.
+   * \return The dict attributes.
+   */
+  TVM_DLL explicit DictAttrs(Map<std::string, ObjectRef> dict);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(DictAttrs, Attrs, DictAttrsNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode);
+};
 
 // Namespace containing detail implementations
 namespace detail {
diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index eceafec75fa12..e37374ae9b4b6 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -211,30 +211,6 @@ class GlobalVar : public RelayExpr {
   TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode);
 };
 
-/*!
- * \brief Base node of all functions.
- *
- * We support several variants of functions throughout the stack.
- * All of the functions shares the same type system(via checked_type)
- * to support cross variant calls.
- *
- * \sa BaseFunc
- */
-class BaseFuncNode : public RelayExprNode {
- public:
-  static constexpr const char* _type_key = "BaseFunc";
-  TVM_DECLARE_BASE_OBJECT_INFO(BaseFuncNode, RelayExprNode);
-};
-
-/*!
- * \brief Managed reference to BaseFuncNode.
- * \sa BaseFuncNode
- */
-class BaseFunc : public RelayExpr {
- public:
-  TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode);
-};
-
 // PrimExprs that are useful as runtime containers.
 //
 /*!
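For context on the attrs.h change: a minimal sketch of how the new DictAttrs reference is meant to be used. The "MyFlag" key and the helper function are hypothetical; the types come from the header above.

    // Sketch: build a DictAttrs and read it back (hypothetical key).
    #include <tvm/ir/attrs.h>
    #include <tvm/ir/expr.h>

    tvm::DictAttrs MakeExampleAttrs() {
      tvm::Map<std::string, tvm::ObjectRef> dict;
      dict.Set("MyFlag", tvm::Integer(1));   // any ObjectRef works as a value
      tvm::DictAttrs attrs(dict);            // constructor replaces DictAttrsNode::make(dict)
      // DictAttrs is an Attrs, so operator-> reaches the node and its dict.
      CHECK_EQ(attrs->dict.size(), 1);
      return attrs;
    }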
diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h
new file mode 100644
index 0000000000000..ad36c287483f4
--- /dev/null
+++ b/include/tvm/ir/function.h
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/ir/function.h
+ * \brief Base function nodes in TVM.
+ */
+#ifndef TVM_IR_FUNCTION_H_
+#define TVM_IR_FUNCTION_H_
+
+#include <tvm/ir/expr.h>
+#include <tvm/ir/attrs.h>
+#include <type_traits>
+#include <string>
+
+
+namespace tvm {
+
+/*!
+ * \brief Base node of all functions.
+ *
+ * We support several variants of functions throughout the stack.
+ * All of the functions share the same type system (via checked_type)
+ * to support cross-variant calls.
+ *
+ * \sa BaseFunc
+ */
+class BaseFuncNode : public RelayExprNode {
+ public:
+  /*! \brief Additional attributes storing the meta-data */
+  DictAttrs attrs;
+
+  /*!
+   * \brief Get a function attribute.
+   *
+   * \param attr_key The attribute key.
+   * \param default_value The default value if the key does not exist, defaults to nullptr.
+   *
+   * \return The result
+   *
+   * \tparam TObjectRef the expected object type.
+   * \throw Error if the key exists but the value does not match TObjectRef
+   *
+   * \code
+   *
+   *  void GetAttrExample(const BaseFunc& f) {
+   *    Integer value = f->GetAttr<Integer>("AttrKey", 0);
+   *  }
+   *
+   * \endcode
+   */
+  template <typename TObjectRef>
+  TObjectRef GetAttr(const std::string& attr_key,
+                     TObjectRef default_value = NullValue<TObjectRef>()) const {
+    static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
+                  "Can only call GetAttr with ObjectRef types.");
+    if (!attrs.defined()) return default_value;
+    auto it = attrs->dict.find(attr_key);
+    if (it != attrs->dict.end()) {
+      return Downcast<TObjectRef>((*it).second);
+    } else {
+      return default_value;
+    }
+  }
+
+  /*!
+   * \brief Check whether the function has a non-zero integer attr.
+   *
+   * This function can be used to check whether an optional
+   * attribute mark (e.g. inline) exists.
+   *
+   * \param attr_key The key to the attribute.
+   * \return The check result.
+   *
+   * \code
+   *
+   *  void HasNonzeroAttrExample(const BaseFunc& f) {
+   *    if (f->HasNonzeroAttr(attr::kInline)) {
+   *      // inline the function.
+   *    }
+   *  }
+   *
+   * \endcode
+   */
+  bool HasNonzeroAttr(const std::string& attr_key) const {
+    return GetAttr<Integer>(attr_key, 0)->value != 0;
+  }
+
+  static constexpr const char* _type_key = "BaseFunc";
+  TVM_DECLARE_BASE_OBJECT_INFO(BaseFuncNode, RelayExprNode);
+};
+
+/*!
+ * \brief Managed reference to BaseFuncNode.
+ * \sa BaseFuncNode
+ */
+class BaseFunc : public RelayExpr {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode);
+};
+
+}  // namespace tvm
+#endif  // TVM_IR_FUNCTION_H_
diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h
index ad5f4d9b8ccbd..23d1f6e5c6283 100644
--- a/include/tvm/ir/module.h
+++ b/include/tvm/ir/module.h
@@ -26,6 +26,7 @@
 
 #include <tvm/ir/type.h>
 #include <tvm/ir/expr.h>
+#include <tvm/ir/function.h>
 #include <tvm/ir/adt.h>
 #include <string>
diff --git a/include/tvm/relay/analysis.h b/include/tvm/relay/analysis.h
index 87dd5b408ca24..fe8fae5ef7885 100644
--- a/include/tvm/relay/analysis.h
+++ b/include/tvm/relay/analysis.h
@@ -26,6 +26,7 @@
 
 #include <tvm/ir/module.h>
 #include <tvm/relay/expr.h>
+#include <tvm/relay/function.h>
 #include <tvm/relay/adt.h>
 #include <string>
 #include <unordered_map>
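The GetAttr/HasNonzeroAttr pair above is the basis for every attribute check in the rest of this diff. A condensed sketch of the calling pattern; the "MyKey" attribute is hypothetical:

    // Sketch: typed attribute lookup on any BaseFunc.
    #include <tvm/ir/function.h>

    void InspectFunc(const tvm::BaseFunc& f) {
      // Typed lookup; the second argument is the default when the key is absent.
      tvm::Integer flag = f->GetAttr<tvm::Integer>("MyKey", 0);
      if (f->HasNonzeroAttr("MyKey")) {
        // Shorthand for GetAttr<Integer>("MyKey", 0)->value != 0.
      }
    }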
diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h
index f627d187fdc21..49356ac8a955f 100644
--- a/include/tvm/relay/expr.h
+++ b/include/tvm/relay/expr.h
@@ -164,113 +164,6 @@ class Var : public Expr {
   TVM_DEFINE_OBJECT_REF_METHODS(Var, RelayExpr, VarNode);
 };
 
-/*!
- * \brief Function (subgraph in computational graph)
- */
-class Function;
-/*! \brief Function container */
-class FunctionNode : public BaseFuncNode {
- public:
-  /*! \brief Function parameters */
-  tvm::Array<Var> params;
-  /*!
-   * \brief
-   * The expression which represents the computation of the function,
-   * the expression may reference the parameters, and the type of it
-   * or sub-expressions may reference the type variables.
-   */
-  Expr body;
-  /*! \brief User annotated return type of the function. */
-  Type ret_type;
-  /*!
-   * \brief Type parameters of the function.
-   *  Enables the function to vary its type based on these.
-   *  This corresponds to template paramaters in c++'s terminology.
-   *
-   * \note This can be usually empty for non-polymorphic functions.
-   */
-  tvm::Array<TypeVar> type_params;
-
-  /*!
-   * \brief The attributes which store metadata about functions.
-   */
-  tvm::Attrs attrs;
-
-  void VisitAttrs(tvm::AttrVisitor* v) {
-    v->Visit("params", &params);
-    v->Visit("body", &body);
-    v->Visit("ret_type", &ret_type);
-    v->Visit("type_params", &type_params);
-    v->Visit("attrs", &attrs);
-    v->Visit("span", &span);
-    v->Visit("_checked_type_", &checked_type_);
-  }
-
-  /*!
-   * \brief Return the derived function annotation of this expression.
-   *
-   * \return The function type annotation.
-   * \note The function type annotation can contain IncompleteType.
-   */
-  TVM_DLL FuncType func_type_annotation() const;
-
-  /*!
-   * \brief Check whether the function is a primitive function.
-   *
-   * \return Whether the function is primitive or not.
-   */
-  bool IsPrimitive() const;
-
-  /*!
-   * \brief Check whether the function is marked as inline.
-   *
-   * \return Whether the function should be inlined or not.
-   */
-  bool IsMarkedInline() const;
-
-  /*!
-   * \brief Check whether the function should use the TVM default compiler to build, or
-   * use other compilers.
-   *
-   * \return Whether the function will be compiled using the default compiler
-   * (e.g. those are used in the TVM stack).
-   */
-  bool UseDefaultCompiler() const;
-
-  TVM_DLL static Function make(tvm::Array<Var> params,
-                               Expr body,
-                               Type ret_type,
-                               tvm::Array<TypeVar> ty_params,
-                               tvm::Attrs attrs = Attrs());
-
-  /*!
-   * \brief Attach the function's parameters to its attributes for use in analysis.
-   * \return The function with its parameters attached.
-   */
-  Function SetParams(const tvm::Map<Var, Constant>& parameters) const;
-
-  /*!
-   * \brief Retrieve the function's parameters.
-   *
-   * \return The function's parameter.
-   */
-  tvm::Map<Var, Constant> GetParams() const;
-
-  static constexpr const char* _type_key = "relay.Function";
-  TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode);
-};
-
-class Function : public BaseFunc {
- public:
-  TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode);
-};
-
-
-TVM_DLL ObjectRef FunctionGetAttr(const Function& func, const std::string& key);
-TVM_DLL Function FunctionSetAttr(const Function& func,
-                                 const std::string& key,
-                                 const ObjectRef& data);
-
 /*!
  * \brief Call corresponds to operator invocation.
  *  Corresponds to the operator in computational graph terminology.
@@ -550,30 +443,6 @@ class TempExpr : public Expr {
   TVM_DEFINE_OBJECT_REF_METHODS(TempExpr, RelayExpr, TempExprNode);
 };
 
-
-/*! \brief namespace of the attributes that are attached to a function. */
-namespace attr {
-/*! \brief Mark the function as a primitive function. */
-constexpr const char* kPrimitive = "Primitive";
-/*!
- * \brief Indicate the compiler that should be used for builing this function.
- * When this is unset or set to "default", the default compilation pipeline will be used.
- */
-constexpr const char* kCompiler = "Compiler";
-/*! \brief Indicate if the function is a closure. */
-constexpr const char* kClosure = "Closure";
-/*! \brief Store a Var to parameter/Constant mapping on a Function. */
-constexpr const char* kParams = "__params__";
-/*! \brief Store the unique external symbol for external compilers. */
-constexpr const char* kExternalSymbol = "ExternalSymbol";
-/*! \brief Mark if the function should be avoided being optimized. */
-constexpr const char* kSkipOptimization = "SkipOptimization";
-/*! \brief Treat the function as a composite operator. */
-constexpr const char* kComposite = "Composite";
-/*! \brief Mark the function to be inlined. */
-constexpr const char* kInline = "Inline";
-}  // namespace attr
-
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_EXPR_H_
diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h
index 68cef94837500..d1c5ca15cc2b8 100644
--- a/include/tvm/relay/expr_functor.h
+++ b/include/tvm/relay/expr_functor.h
@@ -27,16 +27,15 @@
 
 #include <tvm/node/functor.h>
 #include <tvm/ir/error.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/adt.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/function.h>
 #include <string>
 #include <unordered_map>
 #include <utility>
-#include "./expr.h"
-#include "./adt.h"
-#include "./op.h"
-
-
 namespace tvm {
 namespace relay {
diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h
new file mode 100644
index 0000000000000..e644694815608
--- /dev/null
+++ b/include/tvm/relay/function.h
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/function.h
+ * \brief Relay Function.
+ */
+#ifndef TVM_RELAY_FUNCTION_H_
+#define TVM_RELAY_FUNCTION_H_
+
+#include <tvm/ir/function.h>
+#include <tvm/relay/expr.h>
+#include <string>
+
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief Relay Function container
+ * \sa Function
+ */
+class FunctionNode : public BaseFuncNode {
+ public:
+  /*! \brief Function parameters */
+  tvm::Array<Var> params;
+  /*!
+   * \brief
+   * The expression which represents the computation of the function,
+   * the expression may reference the parameters, and the type of it
+   * or sub-expressions may reference the type variables.
+   */
+  Expr body;
+  /*! \brief User annotated return type of the function. */
+  Type ret_type;
+  /*!
+   * \brief Type parameters of the function.
+   *  Enables the function to vary its type based on these.
+   *  This corresponds to template parameters in C++'s terminology.
+   *
+   * \note This can usually be empty for non-polymorphic functions.
+   */
+  tvm::Array<TypeVar> type_params;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("params", &params);
+    v->Visit("body", &body);
+    v->Visit("ret_type", &ret_type);
+    v->Visit("type_params", &type_params);
+    v->Visit("attrs", &attrs);
+    v->Visit("span", &span);
+    v->Visit("_checked_type_", &checked_type_);
+  }
+
+  /*!
+   * \brief Return the derived function annotation of this expression.
+   *
+   * \return The function type annotation.
+   * \note The function type annotation can contain IncompleteType.
+   */
+  TVM_DLL FuncType func_type_annotation() const;
+
+  /*!
+   * \brief Check whether the function should use the TVM default compiler to build, or
+   * use other compilers.
+   *
+   * \return Whether the function will be compiled using the default compiler
+   * (e.g. those are used in the TVM stack).
+   */
+  bool UseDefaultCompiler() const;
+
+  static constexpr const char* _type_key = "relay.Function";
+  TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode);
+};
+
+
+/*!
+ * \brief Managed reference to FunctionNode.
+ * \sa FunctionNode
+ */
+class Function : public BaseFunc {
+ public:
+  /*!
+   * \brief Constructor
+   * \param params The parameters of the function.
+   * \param body The body of the function.
+   * \param ret_type The return type of the function.
+   * \param ty_params The type parameters.
+   * \param attrs Additional function attributes.
+   */
+  TVM_DLL Function(tvm::Array<Var> params,
+                   Expr body,
+                   Type ret_type,
+                   tvm::Array<TypeVar> ty_params,
+                   tvm::DictAttrs attrs = NullValue<DictAttrs>());
+
+  TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode);
+};
+
+/*!
+ * \brief Create a new function that copies func, but overrides
+ *        the attribute value key with the value.
+ *
+ * \param func The input function.
+ * \param attr_key The attribute key.
+ * \param attr_value The new attribute value.
+ *
+ * \returns The new function with updated attributes.
+ *
+ * \note This function performs copy on write optimization for func.
+ *       If we move a uniquely referenced func into WithAttr,
+ *       then no additional copy will be performed.
+ *
+ *       This is also why we make it a free function instead of a member function,
+ *       and why we pass func by value as the first argument.
+ *
+ * \code
+ *
+ *  // Recommended way to trigger copy on write
+ *  func = WithAttr(std::move(func), "key1", value1);
+ *  func = WithAttr(std::move(func), "key2", value2);
+ *
+ * \endcode
+ */
+TVM_DLL Function WithAttr(Function func, const std::string& attr_key, ObjectRef attr_value);
+
+/*!
+ * \brief namespace of the attributes that can be attached to a relay::Function.
+ */
+namespace attr {
+/*! \brief Mark the function as a primitive function. */
+constexpr const char* kPrimitive = "Primitive";
+/*!
+ * \brief Indicate the compiler that should be used for building this function.
+ * When this is unset or set to "default", the default compilation pipeline will be used.
+ */
+constexpr const char* kCompiler = "Compiler";
+/*! \brief Indicate if the function is a closure. */
+constexpr const char* kClosure = "Closure";
+/*! \brief Store a Var to parameter/Constant mapping on a Function. */
+constexpr const char* kParams = "__params__";
+/*! \brief Store the unique external symbol for external compilers. */
+constexpr const char* kExternalSymbol = "ExternalSymbol";
+/*! \brief Mark the function so that optimization passes skip it. */
+constexpr const char* kSkipOptimization = "SkipOptimization";
+/*! \brief Treat the function as a composite operator. */
+constexpr const char* kComposite = "Composite";
+/*! \brief Mark the function to be inlined. */
+constexpr const char* kInline = "Inline";
+}  // namespace attr
+
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_RELAY_FUNCTION_H_
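Because Function now carries TVM_DEFINE_OBJECT_REF_COW_METHOD, a pass holding the only reference to a function can mutate it in place instead of rebuilding it field by field. A sketch of the copy-on-write pattern the WithAttr note above describes (the helper name is illustrative):

    // Sketch: copy-on-write mutation of a relay::Function.
    #include <tvm/relay/function.h>

    tvm::relay::Function ReplaceBody(tvm::relay::Function func,
                                     tvm::relay::Expr new_body) {
      // CopyOnWrite() copies the node only when func is shared.
      tvm::relay::FunctionNode* node = func.CopyOnWrite();
      node->body = std::move(new_body);
      return func;
    }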
diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index 0a2c77a3af454..d5626c80a6dee 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -27,6 +27,7 @@
 
 #include <tvm/ir/transform.h>
 #include <tvm/relay/attrs/transform.h>
 #include <tvm/relay/expr.h>
+#include <tvm/relay/function.h>
 #include <tvm/relay/op.h>
 #include <string>
diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py
index f122956de9c07..a718124d21164 100644
--- a/python/tvm/ir/__init__.py
+++ b/python/tvm/ir/__init__.py
@@ -24,7 +24,7 @@
 from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, BaseFunc, Range
 from .adt import Constructor, TypeData
 from .module import IRModule
-from .attrs import Attrs, make_node
+from .attrs import Attrs, DictAttrs, make_node
 from .container import Array, Map
 from . import transform
diff --git a/python/tvm/ir/attrs.py b/python/tvm/ir/attrs.py
index f30a18f6aee24..3c656fc1b332d 100644
--- a/python/tvm/ir/attrs.py
+++ b/python/tvm/ir/attrs.py
@@ -47,9 +47,7 @@ def keys(self):
         keys : list of str
             List of keys
         """
-        fields = self.list_field_info()
-        for field in fields:
-            yield field.name
+        return [field.name for field in self.list_field_info()]
 
     def get_int_tuple(self, key):
         """Get a python int tuple of a key
@@ -93,6 +91,39 @@ def get_str(self, key):
     def __getitem__(self, item):
         return self.__getattr__(item)
 
+
+@tvm._ffi.register_object
+class DictAttrs(Attrs):
+    """Dictionary attributes.
+    """
+    def _dict(self):
+        """Get internal dict"""
+        return _ffi_api.DictAttrsGetDict(self)
+
+    def keys(self):
+        """Get list of names in the attribute.
+
+        Returns
+        -------
+        keys : list of str
+            List of keys
+        """
+        return [k for k, _ in self.items()]
+
+    def __getitem__(self, k):
+        return self._dict().__getitem__(k)
+
+    def __contains__(self, k):
+        return self._dict().__contains__(k)
+
+    def items(self):
+        """Get items from the map."""
+        return self._dict().items()
+
+    def __len__(self):
+        return self._dict().__len__()
+
+
 def make_node(type_key, **kwargs):
     """Make a new IR node by its type key and fields
diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py
index feced8da05381..00ceb5bd46230 100644
--- a/python/tvm/ir/expr.py
+++ b/python/tvm/ir/expr.py
@@ -53,6 +53,11 @@ def checked_type(self):
 
 class BaseFunc(RelayExpr):
     """Base class of all functions."""
+    @property
+    def attrs(self):
+        """Return the attrs member of the function.
+        """
+        return _ffi_api.BaseFunc_Attrs(self)
 
 
 @tvm._ffi.register_object("relay.GlobalVar")
diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py
index 39e68b8333ffb..a3c625173f4ee 100644
--- a/python/tvm/relay/expr.py
+++ b/python/tvm/relay/expr.py
@@ -266,22 +266,24 @@ def __call__(self, *args):
         """
         return Call(self, args, None, None)
 
-    def get_params(self):
-        return _expr.FunctionGetParams(self)
+    def with_attr(self, attr_key, attr_value):
+        """Create a new copy of the function and update the attribute.
 
-    def set_params(self, params):
-        for key in params:
-            value = params[key]
-            if isinstance(value, NDArray):
-                params[key] = Constant(value)
+        Parameters
+        ----------
+        attr_key : str
+            The attribute key to use.
 
-        return _expr.FunctionSetParams(self, params)
+        attr_value : Object
+            The new attribute value.
 
-    def set_attribute(self, name, ref):
-        return _expr.FunctionSetAttr(self, name, ref)
+        Returns
+        -------
+        func : Function
+            A new copy of the function
+        """
+        return _expr.FunctionWithAttr(self, attr_key, attr_value)
 
-    def get_attribute(self, name):
-        return _expr.FunctionGetAttr(self, name)
 
 
 @register_relay_node
diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc
index 60a543109c500..4c4c997774700 100644
--- a/src/ir/attrs.cc
+++ b/src/ir/attrs.cc
@@ -53,22 +53,27 @@ Array<AttrFieldInfo> DictAttrsNode::ListFieldInfo() const {
   return {};
 }
 
-Attrs DictAttrsNode::make(Map<std::string, ObjectRef> dict) {
+DictAttrs::DictAttrs(Map<std::string, ObjectRef> dict) {
   ObjectPtr<DictAttrsNode> n = make_object<DictAttrsNode>();
   n->dict = std::move(dict);
-  return Attrs(n);
+  data_ = std::move(n);
 }
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 .set_dispatch<DictAttrsNode>([](const ObjectRef& node, ReprPrinter* p) {
-    auto* op = static_cast<const DictAttrsNode*>(node.get());
-    p->stream << op->dict;
+  auto* op = static_cast<const DictAttrsNode*>(node.get());
+  p->stream << op->dict;
 });
 
 TVM_REGISTER_NODE_TYPE(DictAttrsNode);
 
 TVM_REGISTER_NODE_TYPE(AttrFieldInfoNode);
 
+TVM_REGISTER_GLOBAL("ir.DictAttrsGetDict")
+.set_body_typed([](DictAttrs attrs) {
+  return attrs->dict;
+});
+
 using namespace tir;
 // Equal handler.
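The "ir.DictAttrsGetDict" global registered above is what backs the Python DictAttrs._dict helper. For completeness, a sketch of reaching the same packed function from C++; the wrapper function is hypothetical, while the lookup-by-name call is the standard TVM registry API:

    // Sketch: call the registered global through the FFI registry.
    #include <tvm/runtime/registry.h>
    #include <tvm/ir/attrs.h>

    tvm::Map<std::string, tvm::ObjectRef> GetDict(tvm::DictAttrs attrs) {
      const tvm::runtime::PackedFunc* f =
          tvm::runtime::Registry::Get("ir.DictAttrsGetDict");
      CHECK(f != nullptr) << "ir.DictAttrsGetDict is not registered";
      return (*f)(attrs);
    }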
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index 6244c7645accb..3f0d66e198283 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -23,6 +23,7 @@
  */
 #include <tvm/ir/expr.h>
+#include <tvm/ir/function.h>
 // NOTE: reverse dependency on top/tir.
 // These dependencies do not happen at the interface-level,
 // and are only used in minimum cases where they are clearly marked.
diff --git a/src/ir/function.cc b/src/ir/function.cc
new file mode 100644
index 0000000000000..1be259fe693e4
--- /dev/null
+++ b/src/ir/function.cc
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/ir/function.cc
+ * \brief The function data structure.
+ */
+#include <tvm/ir/function.h>
+#include <tvm/runtime/registry.h>
+
+namespace tvm {
+
+TVM_REGISTER_GLOBAL("ir.BaseFunc_Attrs")
+.set_body_typed([](BaseFunc func) {
+  return func->attrs;
+});
+}  // namespace tvm
diff --git a/src/ir/module.cc b/src/ir/module.cc
index 04fe5d55bceb5..45f39d5ade889 100644
--- a/src/ir/module.cc
+++ b/src/ir/module.cc
@@ -138,7 +138,7 @@ relay::Function RunTypeCheck(const IRModule& mod,
                << std::endl;
   }
   func =
-      relay::FunctionNode::make(concat(func->params, fv),
+      relay::Function(concat(func->params, fv),
                                 func->body,
                                 func->ret_type,
                                 concat(func->type_params, ftv),
@@ -296,7 +296,7 @@ IRModule IRModule::FromExpr(
   if (auto* func_node = expr.as<relay::FunctionNode>()) {
     func = GetRef<relay::Function>(func_node);
   } else {
-    func = relay::FunctionNode::make(
+    func = relay::Function(
         relay::FreeVars(expr), expr, Type(),
         relay::FreeTypeVars(expr, mod), {});
   }
@@ -363,7 +363,7 @@ TVM_REGISTER_GLOBAL("ir.Module_Add")
     auto func = mod_copy->Lookup(gv->name_hint);
     mod->Add(var, Downcast<relay::Function>(func), update);
   } else {
-    auto func = relay::FunctionNode::make({}, Downcast<relay::Expr>(val), Type(nullptr), {});
+    auto func = relay::Function({}, Downcast<relay::Expr>(val), Type(nullptr), {});
     mod->Add(var, func, update);
   }
   *ret = mod;
diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc
index d0a7da9f1ba9b..8d6b10e9462b8 100644
--- a/src/relay/backend/compile_engine.cc
+++ b/src/relay/backend/compile_engine.cc
@@ -617,15 +617,13 @@ class CompileEngineImpl : public CompileEngineNode {
       auto src_func = it.first->source_func;
       CHECK(src_func.defined());
       if (!src_func->UseDefaultCompiler()) {
-        auto compiler = FunctionGetAttr(src_func, attr::kCompiler);
-        const tvm::tir::StringImmNode* code_gen = compiler.as<tvm::tir::StringImmNode>();
-        CHECK(code_gen) << "No external codegen is set";
+        auto code_gen = src_func->GetAttr<tir::StringImm>(attr::kCompiler);
+        CHECK(code_gen.defined()) << "No external codegen is set";
         if (ext_mods.find(code_gen->value) == ext_mods.end()) {
           ext_mods[code_gen->value] = IRModule({}, {});
         }
-        auto ext_symbol = FunctionGetAttr(src_func, attr::kExternalSymbol);
-        const tvm::tir::StringImmNode* symbol_name = ext_symbol.as<tvm::tir::StringImmNode>();
-        CHECK(symbol_name) << "No external symbol is set for:\n" << AsText(src_func, false);
+        auto symbol_name = src_func->GetAttr<tir::StringImm>(attr::kExternalSymbol);
+        CHECK(symbol_name.defined()) << "No external symbol is set for:\n" << AsText(src_func, false);
         auto gv = GlobalVar(symbol_name->value);
         ext_mods[code_gen->value]->Add(gv, src_func);
         cached_ext_funcs.push_back(it.first);
@@ -694,8 +692,9 @@ class CompileEngineImpl : public CompileEngineNode {
     if (!key->source_func->UseDefaultCompiler()) {
       auto cache_node = make_object<CachedFuncNode>();
       const auto name_node =
-          FunctionGetAttr(key->source_func, attr::kExternalSymbol).as<tir::StringImmNode>();
-      CHECK(name_node != nullptr) << "External function has not been attached a name yet.";
+          key->source_func->GetAttr<tir::StringImm>(attr::kExternalSymbol);
+      CHECK(name_node.defined())
+          << "External function has not been attached a name yet.";
       cache_node->func_name = name_node->value;
       cache_node->target = tvm::target::ext_dev();
       value->cached_func = CachedFunc(cache_node);
diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h
index 2a88d4b7996ab..49298198187a3 100644
--- a/src/relay/backend/contrib/codegen_c/codegen_c.h
+++ b/src/relay/backend/contrib/codegen_c/codegen_c.h
@@ -68,8 +68,8 @@ class CSourceModuleCodegenBase {
    */
   std::string GetExtSymbol(const Function& func) const {
     const auto name_node =
-        FunctionGetAttr(func, attr::kExternalSymbol).as<tir::StringImmNode>();
-    CHECK(name_node != nullptr) << "Fail to retrieve external symbol.";
+        func->GetAttr<tir::StringImm>(attr::kExternalSymbol);
+    CHECK(name_node.defined()) << "Fail to retrieve external symbol.";
     std::string ext_symbol = name_node->value;
     return ext_symbol;
   }
diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc
index f28d5415449fe..032ebcd22d206 100644
--- a/src/relay/backend/graph_runtime_codegen.cc
+++ b/src/relay/backend/graph_runtime_codegen.cc
@@ -415,7 +415,7 @@ class GraphRuntimeCodegen
     } else {
       LOG(FATAL) << "TVM runtime does not support calls to " << op->op->GetTypeKey();
     }
-    if (!func->IsPrimitive()) {
+    if (!func->HasNonzeroAttr(attr::kPrimitive)) {
       LOG(FATAL) << "TVM only support calls to primitive functions "
                  << "(i.e functions composed of fusable operator invocations)";
     }
diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc
index e15b4cbef725b..bf7df5662d9be 100644
--- a/src/relay/backend/interpreter.cc
+++ b/src/relay/backend/interpreter.cc
@@ -516,7 +516,7 @@ class Interpreter :
     }
 
     if (is_dyn) {
-      CHECK(func->IsPrimitive());
+      CHECK(func->HasNonzeroAttr(attr::kPrimitive));
       out_shapes = ComputeDynamicShape(func, args);
     }
@@ -556,7 +556,7 @@ class Interpreter :
                           const tvm::Array<ObjectRef>& args,
                           const Var& bind = Var()) {
     // Get a reference to the function inside the closure.
-    if (closure->func->IsPrimitive()) {
+    if (closure->func->HasNonzeroAttr(attr::kPrimitive)) {
       return InvokePrimitiveOp(closure->func, args);
     }
     auto func = closure->func;
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index e3c8d12a6e66d..f8ead346d3f6e 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -442,7 +442,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
                         const Expr& outputs) {
     std::vector<Index> argument_registers;
 
-    CHECK(func->IsPrimitive())
+    CHECK_NE(func->GetAttr<Integer>(attr::kPrimitive, 0)->value, 0)
       << "internal error: invoke_tvm_op requires the first argument to be a relay::Function";
 
     auto input_tuple = inputs.as<TupleNode>();
@@ -650,7 +650,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
   }
 
   void VisitExpr_(const FunctionNode* func_node) {
-    if (!func_node->IsPrimitive()) {
+    if (!func_node->HasNonzeroAttr(attr::kPrimitive)) {
       LOG(FATAL) << "local functions should have been removed by lambda lifting:" << std::endl
                  << "Program: " << AsText(GetRef<Function>(func_node), false) << std::endl
                  << "AST: " << GetRef<Function>(func_node);
diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc
index 25a9bcd38416d..0eb6c1a0a68c8 100644
--- a/src/relay/backend/vm/inline_primitives.cc
+++ b/src/relay/backend/vm/inline_primitives.cc
@@ -86,7 +86,7 @@ struct PrimitiveInliner : ExprMutator {
     }
 
     if (auto func = op.as<FunctionNode>()) {
-      if (func->IsPrimitive()) {
+      if (func->HasNonzeroAttr(attr::kPrimitive)) {
        tvm::Array<Expr> call_args;
        for (auto arg : call->args) {
          auto new_arg = VisitExpr(arg);
@@ -109,7 +109,7 @@ struct PrimitiveInliner : ExprMutator {
   }
 
   Expr VisitExpr_(const FunctionNode* func) {
-    if (func->IsPrimitive()) {
+    if (func->HasNonzeroAttr(attr::kPrimitive)) {
       return GetRef<Expr>(func);
     } else {
       return ExprMutator::VisitExpr_(func);
@@ -128,7 +128,7 @@ struct PrimitiveInliner : ExprMutator {
         DLOG(INFO) << "Before inlining primitives: " << global
                    << std::endl << AsText(func, false);
 
-        func = FunctionNode::make(func->params,
+        func = Function(func->params,
                                   VisitExpr(func->body),
                                   func->ret_type,
                                   func->type_params,
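These backend changes all follow one pattern: a typed GetAttr<tir::StringImm> plus a defined() check replaces the old FunctionGetAttr(...).as<StringImmNode>() dance. A condensed sketch, assuming a function that may or may not carry kCompiler:

    // Sketch: optional string attribute with a fallback.
    #include <tvm/relay/function.h>
    #include <tvm/tir/expr.h>
    #include <string>

    std::string CompilerFor(const tvm::relay::Function& f) {
      auto name = f->GetAttr<tvm::tir::StringImm>(tvm::relay::attr::kCompiler);
      // An undefined result means the attribute is absent.
      return name.defined() ? name->value : "default";
    }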
diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc
index 5cf66c5807b04..987fdcb1d920b 100644
--- a/src/relay/backend/vm/lambda_lift.cc
+++ b/src/relay/backend/vm/lambda_lift.cc
@@ -43,13 +43,11 @@ inline std::string GenerateName(const Function& func) {
 }
 
 bool IsClosure(const Function& func) {
-  ObjectRef res = FunctionGetAttr(func, attr::kClosure);
-  const tir::IntImmNode* pval = res.as<tir::IntImmNode>();
-  return pval && pval->value != 0;
+  return func->GetAttr<Integer>(attr::kClosure, 0)->value != 0;
 }
 
-Function MarkClosure(const Function& func) {
-  return FunctionSetAttr(func, attr::kClosure, tvm::Integer(1));
+Function MarkClosure(Function func) {
+  return WithAttr(std::move(func), attr::kClosure, tvm::Integer(1));
 }
 
 /* The goal of this class is to lift out any nested functions into top-level
@@ -65,7 +63,7 @@ class LambdaLifter : public ExprMutator {
   Expr VisitExpr_(const LetNode* let_node) final {
     bool is_lambda = false;
     if (auto func = let_node->value.as<FunctionNode>()) {
-      if (!func->IsPrimitive()) {
+      if (!func->HasNonzeroAttr(attr::kPrimitive)) {
         is_lambda = true;
         letrec_.push_back(let_node->var);
       }
@@ -96,7 +94,7 @@ class LambdaLifter : public ExprMutator {
     auto func = GetRef<Function>(func_node);
 
     // We should not transform primitive functions.
-    if (func->IsPrimitive()) {
+    if (func->HasNonzeroAttr(attr::kPrimitive)) {
       return std::move(func);
     }
@@ -151,10 +149,10 @@ class LambdaLifter : public ExprMutator {
     // code for the closure.
     Function lifted_func;
     if (captured_vars.size() == 0 && free_type_vars.size() == 0) {
-      lifted_func = FunctionNode::make(body->params, body->body, body->ret_type, body->type_params);
+      lifted_func = Function(body->params, body->body, body->ret_type, body->type_params);
     } else {
       lifted_func =
-          FunctionNode::make(captured_vars, body, func->func_type_annotation(), free_type_vars);
+          Function(captured_vars, body, func->func_type_annotation(), free_type_vars);
       lifted_func = MarkClosure(lifted_func);
     }
@@ -191,7 +189,7 @@ class LambdaLifter : public ExprMutator {
       if (auto* n = pair.second.as<FunctionNode>()) {
         if (!n->UseDefaultCompiler()) continue;
         auto func = GetRef<Function>(n);
-        func = FunctionNode::make(func->params,
+        func = Function(func->params,
                                   VisitExpr(func->body),
                                   func->ret_type,
                                   func->type_params,
diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc
index b87877a828b66..86961ef259812 100644
--- a/src/relay/ir/expr.cc
+++ b/src/relay/ir/expr.cc
@@ -110,118 +110,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       p->stream << ")";
     });
 
-Function FunctionNode::make(tvm::Array<Var> params,
-                            Expr body,
-                            Type ret_type,
-                            tvm::Array<TypeVar> type_params,
-                            tvm::Attrs attrs) {
-  ObjectPtr<FunctionNode> n = make_object<FunctionNode>();
-  CHECK(params.defined());
-  CHECK(type_params.defined());
-  n->params = std::move(params);
-  n->body = std::move(body);
-  n->ret_type = std::move(ret_type);
-  n->type_params = std::move(type_params);
-  n->attrs = std::move(attrs);
-  return Function(n);
-}
-
-FuncType FunctionNode::func_type_annotation() const {
-  Array<Type> param_types;
-  for (auto param : this->params) {
-    Type param_type = (param->type_annotation.defined()) ? param->type_annotation
-                                                         : IncompleteType(Kind::kType);
-    param_types.push_back(param_type);
-  }
-
-  Type ret_type = (this->ret_type.defined()) ? this->ret_type
-                                             : IncompleteType(Kind::kType);
-  return FuncType(param_types, ret_type, this->type_params, {});
-}
-
-bool FunctionNode::IsPrimitive() const {
-  ObjectRef res = FunctionGetAttr(GetRef<Function>(this), attr::kPrimitive);
-  const tir::IntImmNode* pval = res.as<tir::IntImmNode>();
-  return pval && pval->value != 0;
-}
-
-bool FunctionNode::IsMarkedInline() const {
-  ObjectRef res = FunctionGetAttr(GetRef<Function>(this), attr::kInline);
-  const tir::IntImmNode* pval = res.as<tir::IntImmNode>();
-  return pval && pval->value != 0;
-}
-
-Function FunctionNode::SetParams(const tvm::Map<Var, Constant>& parameters) const {
-  return FunctionSetAttr(GetRef<Function>(this), attr::kParams, parameters);
-}
-
-TVM_REGISTER_GLOBAL("relay._expr.FunctionSetParams")
-.set_body_typed(
-  [](const Function& func, const tvm::Map<Var, Constant>& parameters) {
-    return func->SetParams(parameters);
-});
-
-tvm::Map<Var, Constant> FunctionNode::GetParams() const {
-  auto node_ref = FunctionGetAttr(GetRef<Function>(this), attr::kParams);
-  return Downcast<tvm::Map<Var, Constant>>(node_ref);
-}
-
-TVM_REGISTER_GLOBAL("relay._expr.FunctionGetParams")
-.set_body_typed([](const Function& func) {
-  return func->GetParams();
-});
-
-bool FunctionNode::UseDefaultCompiler() const {
-  ObjectRef res = FunctionGetAttr(GetRef<Function>(this), attr::kCompiler);
-  const tir::StringImmNode* pval = res.as<tir::StringImmNode>();
-  return pval == nullptr || pval->value == "default";
-}
-
-ObjectRef FunctionGetAttr(const Function& func, const std::string& key) {
-  if (!func->attrs.defined()) { return ObjectRef(); }
-
-  const DictAttrsNode* dict_attrs = func->attrs.as<DictAttrsNode>();
-  CHECK(dict_attrs);
-  auto it = dict_attrs->dict.find(key);
-  if (it != dict_attrs->dict.end()) {
-    return (*it).second;
-  } else {
-    return ObjectRef();
-  }
-}
-
-Function FunctionSetAttr(const Function& func, const std::string& key, const ObjectRef& data) {
-  const DictAttrsNode* dattrs = func->attrs.as<DictAttrsNode>();
-  Attrs func_attrs;
-  if (dattrs) {
-    Map<std::string, ObjectRef> dict = dattrs->dict;
-    dict.Set(key, data);
-    func_attrs = DictAttrsNode::make(dict);
-  } else {
-    Map<std::string, ObjectRef> dict = {{key, data}};
-    func_attrs = DictAttrsNode::make(dict);
-  }
-
-  return FunctionNode::make(
-      func->params,
-      func->body,
-      func->ret_type,
-      func->type_params,
-      func_attrs);
-}
-
-TVM_REGISTER_NODE_TYPE(FunctionNode);
-
-TVM_REGISTER_GLOBAL("relay._make.Function")
-.set_body_typed(FunctionNode::make);
-
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-.set_dispatch<FunctionNode>([](const ObjectRef& ref, ReprPrinter* p) {
-    auto* node = static_cast<const FunctionNode*>(ref.get());
-    p->stream << "FunctionNode(" << node->params << ", " << node->ret_type
-              << ", " << node->body << ", " << node->type_params << ", "
-              << node->attrs << ")";
-});
 
 Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs,
                     Array<Type> type_args) {
@@ -360,18 +248,6 @@ TVM_REGISTER_GLOBAL("relay._expr.TempExprRealize")
   return temp->Realize();
 });
 
-TVM_REGISTER_GLOBAL("relay._expr.FunctionSetAttr")
-.set_body_typed(
-  [](Function func, std::string name, ObjectRef ref) {
-    return FunctionSetAttr(func, name, ref);
-});
-
-TVM_REGISTER_GLOBAL("relay._expr.FunctionGetAttr")
-.set_body_typed(
-  [](Function func, std::string name) {
-    return FunctionGetAttr(func, name);
-});
-
 TVM_REGISTER_GLOBAL("relay._make.Any")
 .set_body_typed([]() { return Any::make(); });
 
diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc
index c525b9eb7324e..87d5f320f883e 100644
--- a/src/relay/ir/expr_functor.cc
+++ b/src/relay/ir/expr_functor.cc
@@ -109,7 +109,7 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
       body.same_as(op->body)) {
     return GetRef<Expr>(op);
   } else {
-    return FunctionNode::make(params, body, ret_type, ty_params, op->attrs);
+    return Function(params, body, ret_type, ty_params, op->attrs);
   }
 }
 
@@ -417,7 +417,7 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
         new_params.size() == func->params.size()) {
       return expr;
     }
-    auto ret = FunctionNode::make(new_params,
+    auto ret = Function(new_params,
                                   new_body,
                                   func->ret_type,
                                   func->type_params,
@@ -431,7 +431,7 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
         new_params.push_back(v);
       }
     }
-    ret = FunctionNode::make(new_params,
+    ret = Function(new_params,
                              new_body,
                              func->ret_type,
                              func->type_params,
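The pass updates from here on are mechanical: every FunctionNode::make call becomes a call to the new Function constructor with the same arguments. A minimal sketch of the before/after shape (the helper is hypothetical):

    // Sketch: old vs. new construction of a trivial identity function.
    #include <tvm/relay/function.h>

    tvm::relay::Function MakeIdentity(tvm::relay::Var x) {
      // Before this change: relay::FunctionNode::make({x}, x, Type(), {});
      return tvm::relay::Function({x}, x, tvm::relay::Type(), {});
    }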
diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc
new file mode 100644
index 0000000000000..b685a9f6eb1c1
--- /dev/null
+++ b/src/relay/ir/function.cc
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tvm/relay/ir/function.cc
+ * \brief Function in relay.
+ */
+#include <tvm/relay/function.h>
+
+namespace tvm {
+namespace relay {
+
+Function::Function(tvm::Array<Var> params,
+                   Expr body,
+                   Type ret_type,
+                   tvm::Array<TypeVar> type_params,
+                   DictAttrs attrs) {
+  ObjectPtr<FunctionNode> n = make_object<FunctionNode>();
+  CHECK(params.defined());
+  CHECK(type_params.defined());
+  n->params = std::move(params);
+  n->body = std::move(body);
+  n->ret_type = std::move(ret_type);
+  n->type_params = std::move(type_params);
+  n->attrs = std::move(attrs);
+  data_ = std::move(n);
+}
+
+FuncType FunctionNode::func_type_annotation() const {
+  Array<Type> param_types;
+  for (auto param : this->params) {
+    Type param_type = (param->type_annotation.defined()) ? param->type_annotation
+                                                         : IncompleteType(Kind::kType);
+    param_types.push_back(param_type);
+  }
+
+  Type ret_type = (this->ret_type.defined()) ? this->ret_type
+                                             : IncompleteType(Kind::kType);
+  return FuncType(param_types, ret_type, this->type_params, {});
+}
+
+bool FunctionNode::UseDefaultCompiler() const {
+  tir::StringImm val = this->GetAttr<tir::StringImm>(attr::kCompiler);
+  return !val.defined() || val->value == "default";
+}
+
+Function WithAttr(Function func, const std::string& attr_key, ObjectRef attr_value) {
+  FunctionNode* node = func.CopyOnWrite();
+  if (node->attrs.defined()) {
+    node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
+  } else {
+    Map<std::string, ObjectRef> dict = {{attr_key, attr_value}};
+    node->attrs = DictAttrs(dict);
+  }
+  return func;
+}
+
+
+TVM_REGISTER_NODE_TYPE(FunctionNode);
+
+TVM_REGISTER_GLOBAL("relay._make.Function")
+.set_body_typed([](tvm::Array<Var> params,
+                   Expr body,
+                   Type ret_type,
+                   tvm::Array<TypeVar> ty_params,
+                   tvm::DictAttrs attrs) {
+  return Function(params, body, ret_type, ty_params, attrs);
+});
+
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+.set_dispatch<FunctionNode>([](const ObjectRef& ref, ReprPrinter* p) {
+    auto* node = static_cast<const FunctionNode*>(ref.get());
+    p->stream << "FunctionNode(" << node->params << ", " << node->ret_type
+              << ", " << node->body << ", " << node->type_params << ", "
+              << node->attrs << ")";
+});
+
+TVM_REGISTER_GLOBAL("relay._expr.FunctionWithAttr")
+.set_body_typed(
+  [](Function func, std::string name, ObjectRef ref) {
+    return WithAttr(std::move(func), name, ref);
+});
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc
index d5cd5c91ff374..919b06604efd8 100644
--- a/src/relay/ir/transform.cc
+++ b/src/relay/ir/transform.cc
@@ -139,9 +139,8 @@ IRModule FunctionPassNode::operator()(const IRModule& mod,
 }
 
 bool FunctionPassNode::SkipFunction(const Function& func) const {
-  ObjectRef skip_opt = FunctionGetAttr(func, attr::kSkipOptimization);
-  const tir::IntImmNode* pval = skip_opt.as<tir::IntImmNode>();
-  return (pval && pval->value != 0) || (!func->UseDefaultCompiler());
+  return func->GetAttr<Integer>(attr::kSkipOptimization, 0)->value != 0 ||
+         !(func->UseDefaultCompiler());
 }
 
 Pass CreateFunctionPass(
diff --git a/src/relay/pass/call_graph.h b/src/relay/pass/call_graph.h
index 684e11a7600f0..28c3a2efe7792 100644
--- a/src/relay/pass/call_graph.h
+++ b/src/relay/pass/call_graph.h
@@ -30,6 +30,7 @@
 
 #include <tvm/ir/module.h>
 #include <tvm/relay/expr.h>
+#include <tvm/relay/function.h>
 #include <memory>
 #include <string>
 #include <unordered_map>
diff --git a/src/relay/pass/de_duplicate.cc b/src/relay/pass/de_duplicate.cc
index d8167601cb077..598289fe3fe9b 100644
--- a/src/relay/pass/de_duplicate.cc
+++ b/src/relay/pass/de_duplicate.cc
@@ -78,7 +78,7 @@ Expr DeDup(const Expr& e) {
     for (const Var& param : op->params) {
       params.push_back(Fresh(param));
     }
-    return FunctionNode::make(params,
+    return Function(params,
                               VisitExpr(op->body),
                               VisitType(op->ret_type),
                               type_params,
diff --git a/src/relay/pass/device_annotation.cc b/src/relay/pass/device_annotation.cc
index c87acd021a0b8..61ccd591b7595 100644
--- a/src/relay/pass/device_annotation.cc
+++ b/src/relay/pass/device_annotation.cc
@@ -521,13 +521,13 @@ Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device) {
       }
       CHECK_GT(new_body.size(), 0U);
       if (new_body.size() == 1) {
-        return FunctionNode::make(params, new_body[0], Type(nullptr),
+        return Function(params, new_body[0], Type(nullptr),
                         fn->type_params, fn->attrs);
       } else if (tuple->fields.size() == new_body.size()) {
         return new_expr;
       } else {
         Tuple tuple_body = TupleNode::make(new_body);
-        return FunctionNode::make(params, tuple_body, Type(nullptr),
+        return Function(params, tuple_body, Type(nullptr),
                         fn->type_params, fn->attrs);
       }
     } else {
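WithAttr takes func by value so that chained updates on a uniquely referenced function stay allocation-free; the graph-partitioning change later in this diff uses exactly this shape. A sketch with illustrative attribute values (the helper name and symbol are hypothetical):

    // Sketch: chaining WithAttr updates via std::move.
    #include <tvm/relay/function.h>
    #include <tvm/tir/expr.h>
    #include <string>

    tvm::relay::Function MarkExternal(tvm::relay::Function f, const std::string& symbol) {
      using namespace tvm::relay;
      f = WithAttr(std::move(f), attr::kPrimitive, tvm::Integer(1));
      f = WithAttr(std::move(f), attr::kExternalSymbol,
                   tvm::tir::StringImmNode::make(symbol));
      return f;  // each step mutated in place, since f was uniquely held
    }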
diff --git a/src/relay/pass/eta_expand.cc b/src/relay/pass/eta_expand.cc
index b274460bbcff3..978a3a6aa2070 100644
--- a/src/relay/pass/eta_expand.cc
+++ b/src/relay/pass/eta_expand.cc
@@ -111,7 +111,7 @@ class EtaExpander : public ExprMutator {
     Expr body = CallNode::make(cons, params, Attrs());
     Type ret_type = TypeCall(cons->belong_to, type_params);
 
-    return FunctionNode::make(
+    return Function(
         Downcast<tvm::Array<Var>>(params),
         body,
         ret_type,
@@ -135,7 +135,7 @@ class EtaExpander : public ExprMutator {
       args.push_back(var);
     }
 
-    return FunctionNode::make(
+    return Function(
         args,
         CallNode::make(gvar, params),
         func->ret_type,
diff --git a/src/relay/pass/feature.cc b/src/relay/pass/feature.cc
index 43f28ae284522..f4842b185d4e3 100644
--- a/src/relay/pass/feature.cc
+++ b/src/relay/pass/feature.cc
@@ -63,7 +63,7 @@ FeatureSet DetectFeature(const Expr& expr) {
     DETECT_DEFAULT_CONSTRUCT(Tuple)
     DETECT_DEFAULT_CONSTRUCT(TupleGetItem)
     DETECT_CONSTRUCT(Function, {
-      if (!op->IsPrimitive()) {
+      if (!op->HasNonzeroAttr(attr::kPrimitive)) {
         ExprVisitor::VisitExpr_(op);
       }
     })
diff --git a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc
index f79f3ee560d26..2b4cc32bd7906 100644
--- a/src/relay/pass/fold_constant.cc
+++ b/src/relay/pass/fold_constant.cc
@@ -209,7 +209,7 @@ class ConstantFolder : public ExprMutator {
       func = Downcast<Function>(expr);
     } else {
       // TODO(@jroesch): fix this
-      func = FunctionNode::make(FreeVars(expr), expr, Type(), FreeTypeVars(expr, module_), {});
+      func = Function(FreeVars(expr), expr, Type(), FreeTypeVars(expr, module_), {});
     }
     auto mod = IRModule(
         {},
diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc
index fd8e6fd0f59f5..a96d2a20ad4b2 100644
--- a/src/relay/pass/fuse_ops.cc
+++ b/src/relay/pass/fuse_ops.cc
@@ -853,7 +853,7 @@ class FuseMutator : private ExprMutator {
   // Skip primitive function.
   Expr VisitExpr_(const FunctionNode* fn_node) {
-    if (fn_node->IsPrimitive()) {
+    if (fn_node->HasNonzeroAttr(attr::kPrimitive)) {
       return GetRef<Expr>(fn_node);
     } else {
       return ExprMutator::VisitExpr_(fn_node);
@@ -933,8 +933,8 @@ class FuseMutator : private ExprMutator {
     } visitor;
     visitor(body);
     const GroupInfo& ginfo = ginfo_[group];
-    auto func = FunctionNode::make(ginfo.params, body, ret_type, {});
-    func = FunctionSetAttr(func, attr::kPrimitive, tvm::Integer(visitor.has_call));
+    auto func = Function(ginfo.params, body, ret_type, {});
+    func = WithAttr(std::move(func), attr::kPrimitive, tvm::Integer(visitor.has_call));
     return CallNode::make(func, ginfo.arguments, Attrs());
   }
diff --git a/src/relay/pass/gradient.cc b/src/relay/pass/gradient.cc
index 9a03fde37e9b8..233e79faa84d4 100644
--- a/src/relay/pass/gradient.cc
+++ b/src/relay/pass/gradient.cc
@@ -255,7 +255,7 @@ Expr FirstOrderGradient(const Expr& re, const IRModule& mod) {
     return Pair(res.forward, grad);
   });
 
-  return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
+  return Function(f->params, body, GradRetType(GetRef<Function>(f)), {});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.first_order_gradient")
@@ -384,7 +384,7 @@ void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
 }
 
 Expr BPEmpty() {
-  Expr unitF = FunctionNode::make({}, TupleNode::make({}), TupleType::Empty(), {});
+  Expr unitF = Function({}, TupleNode::make({}), TupleType::Empty(), {});
   return RefCreateNode::make(unitF);
 }
@@ -413,7 +413,7 @@ struct ReverseAD : ExprMutator {
       auto x_var = ll->Push(x);
       auto ret = ll->Push(GetRev(call->checked_type(), x_var, ll));
       auto bpv = ll->Push(RefReadNode::make(bp));
-      Expr nbp = FunctionNode::make(
+      Expr nbp = Function(
           {}, LetList::With([&](LetList* ll) {
             // we need a new ReverseAD visitor to avoid clobbering the bp local var
@@ -457,7 +457,7 @@ struct ReverseAD : ExprMutator {
       orig_var->checked_type_ = call->checked_type();
       auto ret = ll->Push(GetRev(call->checked_type(), orig_var, ll));
       auto bpv = ll->Push(RefReadNode::make(bp));
-      Expr nbp = FunctionNode::make(
+      Expr nbp = Function(
           {}, LetList::With([&](LetList* ll) {
             tvm::Array<Expr> rev = rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll));
@@ -583,7 +583,7 @@ Expr Gradient(const Expr& re, const IRModule& mod) {
     };
     return Pair(get_final_result(c, f->body->checked_type()), TupleNode::make(ret));
   });
-  return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
+  return Function(f->params, body, GradRetType(GetRef<Function>(f)), {});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.gradient")
diff --git a/src/relay/pass/inline.cc b/src/relay/pass/inline.cc
index 6c8caeede59ac..f6522c4bd0c49 100644
--- a/src/relay/pass/inline.cc
+++ b/src/relay/pass/inline.cc
@@ -83,7 +83,7 @@ class Inliner : ExprMutator {
   }
 
   Function Inline(const Function& func) {
-    return FunctionNode::make(func->params,
+    return Function(func->params,
                               VisitExpr(func->body),
                               func->ret_type,
                               func->type_params,
@@ -101,7 +101,7 @@ class Inliner : ExprMutator {
     if (!func->body.defined()) return false;
 
     // The function must be annotated with the inline attribute.
-    if (!func->IsMarkedInline()) return false;
+    if (!func->HasNonzeroAttr(attr::kInline)) return false;
 
     // The function is not able to be inlined if any callee under the CallGraph
    // of this function cannot be inlined.
@@ -124,7 +124,7 @@ class Inliner : ExprMutator {
     const auto* fn = base_func.as<FunctionNode>();
     CHECK(fn) << "Expected to work on a Relay function.";
 
-    auto func = FunctionNode::make(fn->params,
+    auto func = Function(fn->params,
                                    fn->body,
                                    fn->ret_type,
                                    fn->type_params,
@@ -198,7 +198,7 @@ IRModule Inline(const IRModule& module) {
     auto base_func = cg->GetGlobalFunction(cgn->GetGlobalVar());
     if (const auto* fn = base_func.as<FunctionNode>()) {
       auto func = GetRef<Function>(fn);
-      if (func->IsMarkedInline()) {
+      if (func->HasNonzeroAttr(attr::kInline)) {
         CHECK_EQ(cgn->GetRefCount(), 0U)
             << cgn->GetNameHint() << " is marked as inline but not inlined.";
         cgn->CleanCallGraphEntries();
diff --git a/src/relay/pass/merge_composite.cc b/src/relay/pass/merge_composite.cc
index 162bf3a2bba65..4a8c5c550b6ad 100644
--- a/src/relay/pass/merge_composite.cc
+++ b/src/relay/pass/merge_composite.cc
@@ -140,9 +140,10 @@ class MergeCompositeWrapper : public ExprMutator {
     if (call->op->IsInstance<FunctionNode>()) {
       Function func = Downcast<Function>(call->op);
       CHECK(func.defined());
-      const auto name_node = FunctionGetAttr(func, attr::kComposite).as<tir::StringImmNode>();
+      const auto name_node =
+          func->GetAttr<tir::StringImm>(attr::kComposite);
       // don't step into existing composite functions
-      if (name_node && name_node->value != "") {
+      if (name_node.defined() && name_node->value != "") {
         tvm::Array<Expr> new_args;
         for (const auto& arg : call->args) {
           auto new_e = this->Mutate(arg);
@@ -166,8 +167,8 @@ class MergeCompositeWrapper : public ExprMutator {
     if (extract.defined()) {
       auto free_vars = FreeVars(extract);
       // make the composite function
-      auto f = FunctionNode::make(free_vars, extract, call->checked_type_, {}, Attrs());
-      f = FunctionSetAttr(f, attr::kComposite, tir::StringImmNode::make(pattern_name_));
+      auto f = Function(free_vars, extract, call->checked_type_, {}, DictAttrs());
+      f = WithAttr(std::move(f), attr::kComposite, tir::StringImmNode::make(pattern_name_));
       // find the expressions associated with the free vars using the args_map
       // this tells us which expressions should be given as inputs to the composite function
       Array<Expr> args;
diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc
index c5857c26b7a95..2e048f002283b 100644
--- a/src/relay/pass/partial_eval.cc
+++ b/src/relay/pass/partial_eval.cc
@@ -820,7 +820,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
   Func VisitFuncStatic(const Function& func, const Expr& var) {
     CHECK(IsAtomic(var));
-    if (func->IsPrimitive()) {
+    if (func->HasNonzeroAttr(attr::kPrimitive)) {
       return ConstEvaluateFunc(func);
     }
     std::vector<std::pair<Var, PStatic> > free_vars;
@@ -881,7 +881,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
   Expr VisitFuncDynamic(const Function& func, const Func& f, const Expr& self) {
     return store_.Extend<Expr>([&]() {
       store_.Invalidate();
-      return FunctionNode::make(func->params,
+      return Function(func->params,
                                 LetList::With([&](LetList* ll) {
                                   std::vector<PStatic> pv;
                                   for (const auto& v : func->params) {
diff --git a/src/relay/pass/partition_graph.cc b/src/relay/pass/partition_graph.cc
index d9600bd24abb5..6ea2b06b4f4ed 100644
--- a/src/relay/pass/partition_graph.cc
+++ b/src/relay/pass/partition_graph.cc
@@ -211,15 +211,18 @@ class Partitioner : public ExprMutator {
       }
 
       auto subgraph_func =
-          FunctionNode::make(params, input, call->checked_type_, {}, Attrs());
+          Function(params, input, call->checked_type_, {});
 
       std::string name = compiler_attrs->compiler + "_" + std::to_string(subgraph->id);
       subgraph_func =
-          FunctionSetAttr(subgraph_func, attr::kExternalSymbol, tir::StringImmNode::make(name));
-      subgraph_func = FunctionSetAttr(subgraph_func, attr::kPrimitive, tvm::Integer(1));
-      subgraph_func = FunctionSetAttr(subgraph_func, attr::kCompiler,
-                                      tvm::tir::StringImmNode::make(compiler_attrs->compiler));
-      subgraph_func = FunctionSetAttr(subgraph_func, attr::kInline, tvm::Integer(1));
+          WithAttr(std::move(subgraph_func), attr::kExternalSymbol, tir::StringImmNode::make(name));
+      subgraph_func =
+          WithAttr(std::move(subgraph_func), attr::kPrimitive, tvm::Integer(1));
+      subgraph_func =
+          WithAttr(std::move(subgraph_func), attr::kCompiler,
+                   tvm::tir::StringImmNode::make(compiler_attrs->compiler));
+      subgraph_func =
+          WithAttr(std::move(subgraph_func), attr::kInline, tvm::Integer(1));
       CHECK(!module_->ContainGlobalVar(name))
           << "Global function " << name << " already exists";
       // Create a global function and add it to the IRModule for the subgraph.
@@ -277,7 +280,7 @@ class Partitioner : public ExprMutator {
         params.push_back(new_param);
       }
       auto body = VisitExpr(op->body);
-      return FunctionNode::make(params, body, op->ret_type, op->type_params, op->attrs);
+      return Function(params, body, op->ret_type, op->type_params, op->attrs);
     }
   }
@@ -351,7 +354,7 @@ class Partitioner : public ExprMutator {
     for (const auto& pair : glob_funcs) {
       if (auto* fn = pair.second.as<FunctionNode>()) {
         auto func = GetRef<Function>(fn);
-        func = FunctionNode::make(func->params,
+        func = Function(func->params,
                                   VisitExpr(func->body),
                                   func->ret_type,
                                   func->type_params,
diff --git a/src/relay/pass/pass_util.h b/src/relay/pass/pass_util.h
index 2d4722bc6759f..a3ab18f1bc807 100644
--- a/src/relay/pass/pass_util.h
+++ b/src/relay/pass/pass_util.h
@@ -91,7 +91,7 @@ Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map);
  */
 inline Expr TransformF(const std::function<Expr(const Expr&)>& func, const Expr& e) {
   if (const FunctionNode* f = e.as<FunctionNode>()) {
-    return FunctionNode::make(f->params, func(f->body), f->ret_type, f->type_params, f->attrs);
+    return Function(f->params, func(f->body), f->ret_type, f->type_params, f->attrs);
   } else {
     return func(e);
   }
diff --git a/src/relay/pass/quantize/annotate.cc b/src/relay/pass/quantize/annotate.cc
index 4b7f15a36d4e0..84d6a0d24257e 100644
--- a/src/relay/pass/quantize/annotate.cc
+++ b/src/relay/pass/quantize/annotate.cc
@@ -99,7 +99,7 @@ Pass QuantizeAnnotate() {
         for (const auto& x : FreeVars(func)) {
           new_params.push_back(x);
         }
-        return FunctionNode::make(new_params,
+        return Function(new_params,
                                   func->body,
                                   func->ret_type,
                                   func->type_params,
diff --git a/src/relay/pass/quantize/calibrate.cc b/src/relay/pass/quantize/calibrate.cc
index f6f0112e3c257..faa5f3097c5d4 100644
--- a/src/relay/pass/quantize/calibrate.cc
+++ b/src/relay/pass/quantize/calibrate.cc
@@ -151,7 +151,7 @@ class StatsCollector : private ExprMutator {
     const FunctionNode* func = new_e.as<FunctionNode>();
     CHECK(func) << "Input shoule be Function";
     Expr new_body = TupleNode::make(std::move(profile_data_));
-    return FunctionNode::make(FreeVars(new_body), new_body, NullValue<Type>(), func->type_params,
-                              func->attrs);
+    return Function(FreeVars(new_body), new_body, NullValue<Type>(), func->type_params,
+                    func->attrs);
   }
diff --git a/src/relay/pass/to_a_normal_form.cc b/src/relay/pass/to_a_normal_form.cc
index c75afd12e9dc6..ddb432611e44c 100644
--- a/src/relay/pass/to_a_normal_form.cc
+++ b/src/relay/pass/to_a_normal_form.cc
@@ -208,10 +208,10 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
   Expr VisitExpr_(const FunctionNode* f, const Var& v) final {
     Expr e = GetRef<Expr>(f);
     Expr ret;
-    if (f->IsPrimitive()) {
+    if (f->HasNonzeroAttr(attr::kPrimitive)) {
       ret = e;
     } else {
-      ret = FunctionNode::make(f->params,
+      ret = Function(f->params,
                     GetSubScope(e, 0)->ll->Get(VisitExpr(f->body)),
                     f->ret_type,
                     f->type_params,
diff --git a/src/relay/pass/to_cps.cc b/src/relay/pass/to_cps.cc
index 293d69667ed7a..49ca8d2ef326a 100644
--- a/src/relay/pass/to_cps.cc
+++ b/src/relay/pass/to_cps.cc
@@ -142,7 +142,7 @@ Function ToCPS(const Function& f,
   }
 
   Expr VisitExpr_(const FunctionNode* op, const MCont& k) final {
-    CHECK(!op->IsPrimitive()) << "primitive func not supported yet.";
+    CHECK(!op->HasNonzeroAttr(attr::kPrimitive)) << "primitive func not supported yet.";
     return k(ToCPS(GetRef<Function>(op), m, cm, vm, answer));
   }
@@ -182,7 +182,7 @@ Function ToCPS(const Function& f,
   Expr reify(const MCont& k) {
     Var arg = VarNode::make("arg", Type());
-    return FunctionNode::make({arg}, k(arg), Type(), {}, {});
+    return Function({arg}, k(arg), Type(), {}, {});
   }
 
   Expr reify(const MCont& k, const std::function<Expr(const MCont&)>& cont) {
@@ -293,7 +293,7 @@ Function ToCPS(const Function& f,
       new_params.push_back(remap(v));
     }
     new_params.push_back(k);
-    return FunctionNode::make(new_params,
+    return Function(new_params,
                               mut.VisitExpr(f->body,
                                             [&](const Expr& e) { return CallNode::make(k, {e}); }),
                               answer,
@@ -328,7 +328,7 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm) {
   Function ret = ToCPS(f, m, cm, &var, answer);
   auto new_type_params = ret->type_params;
   new_type_params.push_back(answer);
-  return FunctionNode::make(ret->params, ret->body, ret->ret_type, new_type_params, ret->attrs);
+  return Function(ret->params, ret->body, ret->ret_type, new_type_params, ret->attrs);
 }
 
 Function ToCPS(const Function& f, const IRModule& m) {
@@ -355,7 +355,7 @@ Function UnCPS(const Function& f) {
   // TODO(@M.K.): make alphaequal work on free term
   // CHECK(AlphaEqual(cont_type, Arrow(new_ret_type, answer_type)));
   auto x = VarNode::make("x", new_ret_type);
-  auto cont = FunctionNode::make({x}, x, new_ret_type, {}, {});
+  auto cont = Function({x}, x, new_ret_type, {}, {});
   tvm::Array<Expr> args;
   for (const auto& p : new_params) {
     args.push_back(p);
@@ -366,7 +366,7 @@ Function UnCPS(const Function& f) {
     type_args.push_back(tp);
   }
   type_args.push_back(new_ret_type);
-  return FunctionNode::make(new_params,
+  return Function(new_params,
                             CallNode::make(f, args, {}, type_args),
                             new_ret_type,
                             new_type_params,
diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc
index 0ad43d03b60de..8aa1ac9c5a8b9 100644
--- a/src/relay/pass/type_solver.cc
+++ b/src/relay/pass/type_solver.cc
@@ -666,7 +666,7 @@ TVM_REGISTER_GLOBAL("relay._analysis._test_type_solver")
     ErrorReporter *err_reporter = new ErrorReporter();
     auto module = IRModule({}, {});
     auto dummy_fn_name = GlobalVar("test");
-    module->Add(dummy_fn_name, FunctionNode::make({}, TupleNode::make({}), Type(), {}, {}));
+    module->Add(dummy_fn_name, Function({}, TupleNode::make({}), Type(), {}, {}));
     auto solver = std::make_shared<TypeSolver>(dummy_fn_name, module, err_reporter);
 
     auto mod = [module, solver, err_reporter](std::string name) -> PackedFunc {
diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc
index a94dce6fe4967..cae71889d7ade 100644
--- a/tests/cpp/relay_build_module_test.cc
+++ b/tests/cpp/relay_build_module_test.cc
@@ -82,7 +82,7 @@ TEST(Relay, BuildModule) {
   auto x = relay::CallNode::make(add_op, {a, b}, tvm::Attrs(), {});
   auto c = relay::VarNode::make("c", tensor_type);
   auto y = relay::CallNode::make(add_op, {x, c}, tvm::Attrs(), {});
-  auto func = relay::FunctionNode::make(relay::FreeVars(y), y, relay::Type(), {});
+  auto func = relay::Function(relay::FreeVars(y), y, relay::Type(), {});
tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); auto B = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); auto C = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); diff --git a/tests/cpp/relay_pass_type_infer_test.cc b/tests/cpp/relay_pass_type_infer_test.cc index 5935e367f432c..bc5e65e59b748 100644 --- a/tests/cpp/relay_pass_type_infer_test.cc +++ b/tests/cpp/relay_pass_type_infer_test.cc @@ -28,11 +28,11 @@ TEST(Relay, SelfReference) { using namespace tvm; auto tensor_type = relay::TensorType({}, DataType::Bool()); auto x = relay::VarNode::make("x", relay::Type()); - auto f = relay::FunctionNode::make(tvm::Array{ x }, x, relay::Type(), {}); + auto f = relay::Function(tvm::Array{ x }, x, relay::Type(), {}); CHECK(f->IsInstance()); auto y = relay::VarNode::make("y", tensor_type); auto call = relay::CallNode::make(f, Array{ y }); - auto fx = relay::FunctionNode::make(tvm::Array{ y }, call, relay::Type(), {}); + auto fx = relay::Function(tvm::Array{ y }, call, relay::Type(), {}); auto mod = IRModule::FromExpr(fx); mod = relay::transform::InferType()(mod); auto type_fx = mod->Lookup("main"); diff --git a/tests/cpp/relay_transform_sequential.cc b/tests/cpp/relay_transform_sequential.cc index e9b6a1cd424fc..d8a0bde5fa6d0 100644 --- a/tests/cpp/relay_transform_sequential.cc +++ b/tests/cpp/relay_transform_sequential.cc @@ -53,7 +53,7 @@ TEST(Relay, Sequential) { // Let expression and varaible a should be dead-code eliminated. auto z3 = relay::LetNode::make(a, c, z2); relay::Function func = - relay::FunctionNode::make(relay::FreeVars(z3), z3, relay::Type(), {}); + relay::Function(relay::FreeVars(z3), z3, relay::Type(), {}); // Get schedule auto reg = tvm::runtime::Registry::Get("relay.op._Register"); @@ -96,7 +96,7 @@ TEST(Relay, Sequential) { auto zz = relay::CallNode::make(add_op, {y1, c1}); zz = relay::CallNode::make(add_op, {zz, zz}); relay::Function expected_func = - relay::FunctionNode::make(relay::FreeVars(zz), zz, relay::Type(), {}); + relay::Function(relay::FreeVars(zz), zz, relay::Type(), {}); // Infer type for the expected function. 
auto mod1 = IRModule::FromExpr(expected_func); diff --git a/tests/cpp/utvm_runtime_standalone_test.cc b/tests/cpp/utvm_runtime_standalone_test.cc index 9e9c08fc36254..14f7de52a22d2 100644 --- a/tests/cpp/utvm_runtime_standalone_test.cc +++ b/tests/cpp/utvm_runtime_standalone_test.cc @@ -58,7 +58,7 @@ TEST(MicroStandaloneRuntime, BuildModule) { auto x = relay::CallNode::make(add_op, {a, b}, tvm::Attrs(), {}); auto c = relay::VarNode::make("c", tensor_type); auto y = relay::CallNode::make(add_op, {x, c}, tvm::Attrs(), {}); - auto func = relay::FunctionNode::make(relay::FreeVars(y), y, relay::Type(), {}); + auto func = relay::Function(relay::FreeVars(y), y, relay::Type(), {}); auto A = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); auto B = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); auto C = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); diff --git a/tests/python/relay/test_call_graph.py b/tests/python/relay/test_call_graph.py index fbbda678b1029..92bb37367ef38 100644 --- a/tests/python/relay/test_call_graph.py +++ b/tests/python/relay/test_call_graph.py @@ -134,7 +134,7 @@ def test_recursive_func(): func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], 'int32')) - func = func.set_attribute("Compiler", tvm.tir.StringImm("a")) + func = func.with_attr("Compiler", tvm.tir.StringImm("a")) mod[sum_up] = func iarg = relay.var('i', shape=[], dtype='int32') mod["main"] = relay.Function([iarg], sum_up(iarg)) diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index e3789988d2f33..bda590f563f07 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -78,9 +78,9 @@ def check_graph_runtime_result(): def set_external_func_attr(func, compiler, ext_symbol): - func = func.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) - func = func.set_attribute("Compiler", tvm.tir.StringImm(compiler)) - func = func.set_attribute("ExternalSymbol", tvm.tir.StringImm(ext_symbol)) + func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) + func = func.with_attr("Compiler", tvm.tir.StringImm(compiler)) + func = func.with_attr("ExternalSymbol", tvm.tir.StringImm(ext_symbol)) return func diff --git a/tests/python/relay/test_external_runtime.py b/tests/python/relay/test_external_runtime.py index 0942cbb941ea9..397c35db53d94 100644 --- a/tests/python/relay/test_external_runtime.py +++ b/tests/python/relay/test_external_runtime.py @@ -307,7 +307,7 @@ def get_synthetic_lib(): gcc_input3 = relay.var('gcc_input3', shape=(10, 10)) subgraph0 = relay.Function([gcc_input0, gcc_input1, gcc_input2, gcc_input3], relay.copy(gcc_input0)) - subgraph0 = subgraph0.set_attribute( + subgraph0 = subgraph0.with_attr( "Primitive", tvm.tir.IntImm("int32", 1)) # Call subgraph0 @@ -320,7 +320,7 @@ def get_synthetic_lib(): gcc_input7 = relay.var('gcc_input7', shape=(10, 10)) subgraph1 = relay.Function([gcc_input4, gcc_input5, gcc_input6, gcc_input7], relay.copy(gcc_input4)) - subgraph1 = subgraph1.set_attribute( + subgraph1 = subgraph1.with_attr( "Primitive", tvm.tir.IntImm("int32", 1)) # Call subgraph1 diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py index cc663a1614fe1..d3d0808734cd0 100644 --- a/tests/python/relay/test_ir_nodes.py +++ b/tests/python/relay/test_ir_nodes.py @@ -169,15 +169,16 @@ def test_function(): body = relay.Tuple(tvm.runtime.convert([])) type_params = tvm.runtime.convert([]) fn = 
relay.Function(params, body, ret_type, type_params) - fn = fn.set_attribute("test_attribute", tvm.tir.StringImm("value")) + fn = fn.with_attr("test_attribute", tvm.tir.StringImm("value")) assert fn.params == params assert fn.body == body assert fn.type_params == type_params assert fn.span == None - assert fn.get_attribute("test_attribute") == "value" + assert fn.attrs["test_attribute"] == "value" str(fn) check_json_roundtrip(fn) + @pytest.mark.skip(reason="AttrsEqualHandler doesn't handle Map so far.") def test_function_attrs(): param_names = ['a', 'b', 'c', 'd'] @@ -190,8 +191,10 @@ def test_function_attrs(): for param in params[:1]: cty = param.type_annotation tensor = np.random.rand(*[int(sh) for sh in cty.shape]).astype(cty.dtype) - model_params[param] = tvm.nd.array(tensor) - fn = fn.set_params(model_params) + model_params[param] = relay.Constant(tvm.nd.array(tensor)) + + fn = fn.with_attr("__params__", model_params) + assert fn.params == params assert fn.body == body assert fn.type_params == type_params @@ -200,7 +203,7 @@ def test_function_attrs(): check_json_roundtrip(fn) json_str = tvm.ir.save_json(fn) fn_after = tvm.ir.load_json(json_str) - model_params_after = fn_after.get_params() + model_params_after = fn_after.attrs["__params__"] after_keys = [item[0] for item in model_params_after.items()] for key1, key2 in zip(model_params, after_keys): assert key1.name_hint == key2.name_hint @@ -296,4 +299,3 @@ def test_conv2d_attrs(): test_tuple_get_item() test_op() test_conv2d_attrs() - diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py index 7e34f48ec7e1a..db59256b76111 100644 --- a/tests/python/relay/test_pass_alpha_equal.py +++ b/tests/python/relay/test_pass_alpha_equal.py @@ -324,7 +324,7 @@ def test_function_attr(): p00 = relay.subtract(z00, w01) q00 = relay.multiply(p00, w02) func0 = relay.Function([x0, w00, w01, w02], q00) - func0 = func0.set_attribute("FuncName", tvm.tir.StringImm("a")) + func0 = func0.with_attr("FuncName", tvm.tir.StringImm("a")) x1 = relay.var('x1', shape=(10, 10)) w10 = relay.var('w10', shape=(10, 10)) @@ -334,7 +334,7 @@ def test_function_attr(): p10 = relay.subtract(z10, w11) q10 = relay.multiply(p10, w12) func1 = relay.Function([x1, w10, w11, w12], q10) - func1 = func1.set_attribute("FuncName", tvm.tir.StringImm("b")) + func1 = func1.with_attr("FuncName", tvm.tir.StringImm("b")) assert not alpha_equal(func0, func1) @@ -665,7 +665,7 @@ def test_fn_attribute(): d = relay.var('d', shape=(10, 10)) add_1 = relay.add(c, d) add_1_fn = relay.Function([c, d], add_1) - add_1_fn = add_1_fn.set_attribute("TestAttribute", tvm.tir.StringImm("test")) + add_1_fn = add_1_fn.with_attr("TestAttribute", tvm.tir.StringImm("test")) add_1_fn = run_opt_pass(add_1_fn, relay.transform.InferType()) assert not relay.analysis.alpha_equal(add_1_fn, add_fn) diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index a66022275c96a..108c91bd2bcb2 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -36,7 +36,7 @@ def expected(): z = relay.exp(y) w = relay.squeeze(z) f1 = relay.Function([x], w) - f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) x = relay.var("x", shape=(10, 20)) y = relay.Call(f1, [x]) return relay.Function([x], y) @@ -78,7 +78,7 @@ def expected(dshape): x = relay.var("p0", shape=dshape) y = relay.add(x, relay.const(1, "float32")) f0 = relay.Function([x], 
y) - f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) # segment 1 x = relay.var("p0", shape=dshape) @@ -90,7 +90,7 @@ def expected(dshape): y1 = relay.add(relay.const(1, "float32"), y) y = relay.add(y, y1) f1 = relay.Function([x, w], y) - f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) # segment 2 x = relay.var("p0", shape=dshape) @@ -100,7 +100,7 @@ def expected(dshape): padding=(1,1), channels=16) f2 = relay.Function([x, w], z2) - f2 = f2.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + f2 = f2.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) # segment 3 x = relay.var("p0", shape=dshape) @@ -112,7 +112,7 @@ def expected(dshape): channels=16) z3 = relay.add(z3, offset) f3 = relay.Function([x, w, offset], z3) - f3 = f3.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + f3 = f3.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) # compose x = relay.var("x", shape=dshape) @@ -145,7 +145,7 @@ def expected(dshape): x = relay.var("x", shape=dshape) pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) f0 = relay.Function([x], pooled) - f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2)) p1 = relay.var("p1", shape=dshape) @@ -153,7 +153,7 @@ def expected(dshape): concat = relay.concatenate((upsampled, p1), axis=1) out = relay.add(concat, relay.const(1, "float32")) f1 = relay.Function([p0, p1], out) - f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) x = relay.var("x", shape=dshape) y = relay.Call(f0, [x]) @@ -184,12 +184,12 @@ def expected(dshape): x = relay.var("x", shape=dshape) pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) f0 = relay.Function([x], pooled) - f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2)) upsampled = relay.nn.upsampling(p0, scale_h=2, scale_w=2, layout="NCHW") f1 = relay.Function([p0], upsampled) - f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) x = relay.var("x", shape=dshape) y = relay.Call(f0, [x]) @@ -219,12 +219,12 @@ def expected(dshape): x = relay.var("p0", shape=dshape) y = relay.add(x, relay.const(1, "float32")) f1 = relay.Function([x], y) - f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) x = relay.var("p01", shape=dshape) y = relay.exp(x) f2 = relay.Function([x], y) - f2 = f2.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + f2 = f2.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) x = relay.var("x", shape=dshape) y = relay.Call(f1, [x]) @@ -258,7 +258,7 @@ def expected(dshape, dtype): p2 = relay.var('p2', shape=dshape, dtype=dtype) fused_gt = relay.Function([p1, p2], relay.op.greater(p1, p2)) - fused_gt = fused_gt.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + fused_gt = fused_gt.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) with sb.if_scope(fused_gt(x, y)): sb.ret(relay.Function([], x)) with sb.else_scope(): @@ -288,13 +288,13 @@ def expected(dim): p1 = relay.var("p1", shape=(3 * dim, dim)) 
matmul = relay.nn.dense(p0, p1) f0 = relay.Function([p0, p1], matmul) - f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) p01 = relay.var("p01", shape=(1, 3 * dim)) splitted = relay.split(p01, indices_or_sections=3, axis=1) out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2]) f1 = relay.Function([p01], out) - f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) X = relay.var("X", shape=(1, dim)) W = relay.var("W", shape=(3 * dim, dim)) @@ -325,13 +325,13 @@ def expected(dim): splitted = relay.split(p0, indices_or_sections=3, axis=1) out = splitted[0] f0 = relay.Function([p0], out) - f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) p01 = relay.var("p01", shape=(1, dim)) p1 = relay.var("p1", shape=(dim, dim)) out = relay.nn.dense(p01, p1) f1 = relay.Function([p01, p1], out) - f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) X = relay.var("X", shape=(1, 3 * dim)) W = relay.var("W", shape=(dim, dim)) @@ -367,7 +367,7 @@ def before(x): def expected(p0): f0 = before(p0) - f1 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + f1 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) x = relay.var("x", shape=dshape) y = relay.Call(f1, [x]) return relay.Function([x], y) @@ -410,18 +410,18 @@ def expected(dshape): p0 = relay.var("p0", shape=dshape) concat = gen_consecutive_tuple(p0) f0 = relay.Function([p0], concat) - f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) p01 = relay.var("p01", shape=(1, dshape[1]*9, dshape[2], dshape[3])) pooled = relay.nn.max_pool2d(p01, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) out = relay.add(pooled, relay.const(1, "float32")) f1 = relay.Function([p01], out) - f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) p02 = relay.var("p02", shape=(1, dshape[1]*9, dshape[2]//2, dshape[3]//2)) out = relay.add(p02, relay.const(1, "float32")) f2 = relay.Function([p02], out) - f2 = f2.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + f2 = f2.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) x = relay.var("x", shape=dshape) y = relay.Call(f0, [x]) @@ -463,36 +463,36 @@ def expected(dshape): p0 = relay.var("p0", shape=dshape) c = conv(p0) f0 = relay.Function(relay.analysis.free_vars(c), c) - f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) p01 = relay.var("p01", shape=dshape) c = conv(p01) f1 = relay.Function(relay.analysis.free_vars(c), c) - f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) p02 = relay.var("p02", shape=dshape) p12 = relay.var("p12", shape=dshape) concat1 = relay.concatenate((p02, p12), axis=1) f_concat1 = relay.Function([p02, p12], concat1) - f_concat1 = f_concat1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + f_concat1 = f_concat1.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) dshape2 = (dshape[0], dshape[1]*2, dshape[2], dshape[3]) p03 = relay.var("p03", shape=dshape2) c = conv(p03) f2 = relay.Function(relay.analysis.free_vars(c), c) - f2 = f2.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + f2 = 
f2.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) p04 = relay.var("p04", shape=dshape2) c = conv(p04) f3 = relay.Function(relay.analysis.free_vars(c), c) - f3 = f3.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + f3 = f3.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) p05 = relay.var("p05", shape=dshape) p15 = relay.var("p15", shape=dshape) concat2 = relay.concatenate((p05, p15), axis=1) f_concat2 = relay.Function([p05, p15], concat2) - f_concat2 = f_concat2.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + f_concat2 = f_concat2.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) x = relay.var("x", shape=dshape) c1 = relay.Call(f0, [x, relay.var("w1")]) @@ -530,7 +530,7 @@ def expected(): u = relay.transpose(y, axes=[0, 1]) w = relay.left_shift(z, u) f1 = relay.Function([x], w) - f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) x = relay.var("x", shape=(10, 20)) y = relay.Call(f1, [x]) return relay.Function([x], y) @@ -561,7 +561,7 @@ def expected(): z = relay.exp(y) w = relay.squeeze(z) f1 = relay.Function([x], w) - f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) x = relay.var("x", shape=(10, 20)) y = relay.Call(f1, [x]) mod = tvm.IRModule() @@ -603,7 +603,7 @@ def expected(): for i in range(max_fused_ops): y = relay.exp(y) f1 = relay.Function([x], y) - f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) x = relay.var("x", shape=(10, 20)) z = relay.Call(f1, [x]) xx = relay.var("pp", shape=(10, 20)) @@ -611,7 +611,7 @@ def expected(): for i in range(n-max_fused_ops): yy = relay.exp(yy) f2 = relay.Function([xx], yy) - f2 = f2.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + f2 = f2.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) zz = relay.Call(f2, [z]) return relay.Function([x], zz) diff --git a/tests/python/relay/test_pass_inline.py b/tests/python/relay/test_pass_inline.py index 7a0954b4887bc..f4943ab6851b6 100644 --- a/tests/python/relay/test_pass_inline.py +++ b/tests/python/relay/test_pass_inline.py @@ -33,7 +33,7 @@ def get_recursive_count_loop(): func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], 'int32')) - func = func.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) + func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1)) mod[sum_up] = func iarg = relay.var('i', shape=[], dtype='int32') mod["main"] = relay.Function([iarg], sum_up(iarg)) @@ -56,7 +56,7 @@ def get_mod(): x11 = relay.var("x11", shape=(3, 5)) g11 = relay.GlobalVar("g11") fn11 = relay.Function([x11], x11) - fn11 = fn11.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) + fn11 = fn11.with_attr("Inline", tvm.tir.IntImm("int32", 1)) mod[g11] = fn11 x1 = relay.var("x1", shape=(3, 5)) @@ -135,7 +135,7 @@ def get_mod(): x11 = relay.var("x11", shape=(3, 5)) g11 = relay.GlobalVar("g11") fn11 = relay.Function([x11], x11) - fn11 = fn11.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) + fn11 = fn11.with_attr("Inline", tvm.tir.IntImm("int32", 1)) mod[g11] = fn11 x1 = relay.var("x1", shape=(3, 5)) @@ -143,7 +143,7 @@ def get_mod(): sb = relay.ScopeBuilder() sb.ret(x1 + y1 + g11(x1)) fn1 = relay.Function([x1, y1], sb.get()) - fn1 = fn1.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) + fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) g1 = relay.GlobalVar("g1") mod[g1] = fn1 @@ -208,8 +208,8 @@ def get_mod(): x11 = relay.var("x11", shape=(3, 5)) 
g11 = relay.GlobalVar("g11") fn11 = relay.Function([x11], x11) - fn11 = fn11.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) - fn11 = fn11.set_attribute("Compiler", tvm.tir.StringImm("a")) + fn11 = fn11.with_attr("Inline", tvm.tir.IntImm("int32", 1)) + fn11 = fn11.with_attr("Compiler", tvm.tir.StringImm("a")) mod[g11] = fn11 x1 = relay.var("x1", shape=(3, 5)) @@ -217,7 +217,7 @@ def get_mod(): sb = relay.ScopeBuilder() sb.ret(x1 + y1 + g11(x1)) fn1 = relay.Function([x1, y1], sb.get()) - fn1 = fn1.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) + fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) g1 = relay.GlobalVar("g1") mod[g1] = fn1 @@ -243,8 +243,8 @@ def expected(): mod = tvm.IRModule({}) x11 = relay.var("x11", shape=(3, 5)) fn11 = relay.Function([x11], x11) - fn11 = fn11.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) - fn11 = fn11.set_attribute("Compiler", tvm.tir.StringImm("a")) + fn11 = fn11.with_attr("Inline", tvm.tir.IntImm("int32", 1)) + fn11 = fn11.with_attr("Compiler", tvm.tir.StringImm("a")) x2 = relay.var("x2", shape=(3, 5)) y2 = relay.var("y2", shape=(3, 5)) @@ -275,7 +275,7 @@ def get_mod(): x = relay.var('x', shape=[], dtype='int32') fn0 = relay.Function([x], x) - fn0 = fn0.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) + fn0 = fn0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) gx = relay.GlobalVar("gx") mod[gx] = fn0 @@ -292,7 +292,7 @@ def get_mod(): func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], "int32")) - func = func.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) + func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1)) mod[sum_up] = func iarg = relay.var("i", shape=[], dtype='int32') mod["main"] = relay.Function([iarg], sum_up(iarg)) @@ -313,7 +313,7 @@ def expected(): func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], 'int32')) - func = func.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) + func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1)) mod[sum_up] = func iarg = relay.var('i', shape=[], dtype='int32') mod["main"] = relay.Function([iarg], sum_up(iarg)) @@ -340,7 +340,7 @@ def get_mod(): y = relay.var("y", shape=(2, 2)) x1 = relay.var("x1", shape=(2, 2)) fn1 = relay.Function([x1], x1) - fn1 = fn1.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) + fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) g1 = relay.GlobalVar("g1") mod[g1] = fn1 mod["main"] = relay.Function([x, y], x + y + g1(x)) @@ -366,8 +366,8 @@ def get_mod(): y = relay.var("y", shape=(2, 2)) x1 = relay.var("x1", shape=(2, 2)) fn1 = relay.Function([x1], x1) - fn1 = fn1.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) - fn1 = fn1.set_attribute("Compiler", tvm.tir.StringImm("a")) + fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) + fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a")) g1 = relay.GlobalVar("g1") mod[g1] = fn1 mod["main"] = relay.Function([x, y], x + y + g1(x)) @@ -379,8 +379,8 @@ def expected(): y = relay.var("y", shape=(2, 2)) x1 = relay.var("x1", shape=(2, 2)) fn1 = relay.Function([x1], x1) - fn1 = fn1.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) - fn1 = fn1.set_attribute("Compiler", tvm.tir.StringImm("a")) + fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) + fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a")) mod["main"] = relay.Function([x, y], x + y + fn1(x)) return mod @@ -398,7 +398,7 @@ def get_mod(): sb = relay.ScopeBuilder() sb.ret(x1 + y1) fn1 = relay.Function([x1, y1], sb.get()) - fn1 = fn1.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) + fn1 = 
fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) g1 = relay.GlobalVar("g1") mod[g1] = fn1 @@ -407,7 +407,7 @@ def get_mod(): sb1 = relay.ScopeBuilder() sb1.ret(x2 - y2) fn2 = relay.Function([x2, y2], sb1.get()) - fn2 = fn2.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) + fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1)) g2 = relay.GlobalVar("g2") mod[g2] = fn2 @@ -445,8 +445,8 @@ def get_mod(): sb = relay.ScopeBuilder() sb.ret(x1 + y1) fn1 = relay.Function([x1, y1], sb.get()) - fn1 = fn1.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) - fn1 = fn1.set_attribute("Compiler", tvm.tir.StringImm("a")) + fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) + fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a")) g1 = relay.GlobalVar("g1") mod[g1] = fn1 @@ -455,8 +455,8 @@ def get_mod(): sb1 = relay.ScopeBuilder() sb1.ret(x2 - y2) fn2 = relay.Function([x2, y2], sb1.get()) - fn2 = fn2.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) - fn2 = fn2.set_attribute("Compiler", tvm.tir.StringImm("b")) + fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1)) + fn2 = fn2.with_attr("Compiler", tvm.tir.StringImm("b")) g2 = relay.GlobalVar("g2") mod[g2] = fn2 @@ -477,16 +477,16 @@ def expected(): sb = relay.ScopeBuilder() sb.ret(x1 + y1) fn1 = relay.Function([x1, y1], sb.get()) - fn1 = fn1.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) - fn1 = fn1.set_attribute("Compiler", tvm.tir.StringImm("a")) + fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) + fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a")) x2 = relay.var("x2", shape=(3, 5)) y2 = relay.var("y2", shape=(3, 5)) sb1 = relay.ScopeBuilder() sb1.ret(x2 - y2) fn2 = relay.Function([x2, y2], sb1.get()) - fn2 = fn2.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) - fn2 = fn2.set_attribute("Compiler", tvm.tir.StringImm("b")) + fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1)) + fn2 = fn2.with_attr("Compiler", tvm.tir.StringImm("b")) p0 = relay.var("p0", shape=(3, 5)) p1 = relay.var("p1", shape=(3, 5)) @@ -507,9 +507,9 @@ def test_inline_globalvar_without_args(): def get_mod(): mod = tvm.IRModule({}) fn1 = relay.Function([], relay.const(1)) - fn1 = fn1.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) + fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) fn2 = relay.Function([], relay.const(2)) - fn2 = fn2.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) + fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1)) g1 = relay.GlobalVar('g1') g2 = relay.GlobalVar('g2') mod[g1] = fn1 @@ -521,9 +521,9 @@ def get_mod(): def expected(): mod = tvm.IRModule({}) fn1 = relay.Function([], relay.const(1)) - fn1 = fn1.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) + fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) fn2 = relay.Function([], relay.const(2)) - fn2 = fn2.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) + fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1)) p = relay.var('p', 'bool') mod['main'] = relay.Function([p], relay.Call( relay.If(p, fn1, fn2), [])) @@ -538,11 +538,11 @@ def test_inline_globalvar_without_args_extern_compiler(): def get_mod(): mod = tvm.IRModule({}) fn1 = relay.Function([], relay.const(1)) - fn1 = fn1.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) - fn1 = fn1.set_attribute("Compiler", tvm.tir.StringImm("a")) + fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) + fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a")) fn2 = relay.Function([], relay.const(2)) - fn2 = fn2.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) - fn2 = 
fn2.set_attribute("Compiler", tvm.tir.StringImm("b")) + fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1)) + fn2 = fn2.with_attr("Compiler", tvm.tir.StringImm("b")) g1 = relay.GlobalVar('g1') g2 = relay.GlobalVar('g2') mod[g1] = fn1 @@ -554,11 +554,11 @@ def get_mod(): def expected(): mod = tvm.IRModule({}) fn1 = relay.Function([], relay.const(1)) - fn1 = fn1.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) - fn1 = fn1.set_attribute("Compiler", tvm.tir.StringImm("a")) + fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) + fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a")) fn2 = relay.Function([], relay.const(2)) - fn2 = fn2.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) - fn2 = fn2.set_attribute("Compiler", tvm.tir.StringImm("b")) + fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1)) + fn2 = fn2.with_attr("Compiler", tvm.tir.StringImm("b")) p = relay.var('p', 'bool') mod['main'] = relay.Function([p], relay.Call( relay.If(p, fn1, fn2), [])) @@ -593,7 +593,7 @@ def get_mod(): sb1 = relay.ScopeBuilder() sb1.ret(x2 - y2) fn2 = relay.Function([x2, y2], sb1.get()) - fn2 = fn2.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) + fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1)) g2 = relay.GlobalVar("g2") mod[g2] = fn2 @@ -659,14 +659,14 @@ def get_mod(): x1 = relay.var("x1", shape=(3, 5)) y1 = relay.var("y1", shape=(3, 5)) fn1 = relay.Function([x1, y1], x1 + y1) - fn1 = fn1.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) + fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) g1 = relay.GlobalVar("g1") mod[g1] = fn1 x2 = relay.var("x2", shape=(3, 5)) y2 = relay.var("y2", shape=(3, 5)) fn2 = relay.Function([x2, y2], x2 - y2) - fn2 = fn2.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) + fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1)) g2 = relay.GlobalVar("g2") mod[g2] = fn2 @@ -699,7 +699,7 @@ def get_mod(): x2 = relay.var("x2", shape=(3, 5)) y2 = relay.var("y2", shape=(3, 5)) fn2 = relay.Function([x2, y2], x2 - g1(x2, y2)) - fn2 = fn2.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) + fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1)) g2 = relay.GlobalVar("g2") mod[g2] = fn2 @@ -728,7 +728,7 @@ def get_mod(): x0 = relay.var("x0", shape=(3, 5)) y0 = relay.var("y0", shape=(3, 5)) fn0 = relay.Function([x0, y0], x0 * y0) - fn0 = fn0.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) + fn0 = fn0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) g0 = relay.GlobalVar("g0") mod[g0] = fn0 @@ -741,7 +741,7 @@ def get_mod(): x2 = relay.var("x2", shape=(3, 5)) y2 = relay.var("y2", shape=(3, 5)) fn2 = relay.Function([x2, y2], x2 - g1(x2, y2)) - fn2 = fn2.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) + fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1)) g2 = relay.GlobalVar("g2") mod[g2] = fn2 return mod @@ -757,7 +757,7 @@ def expected(): x2 = relay.var("x2", shape=(3, 5)) y2 = relay.var("y2", shape=(3, 5)) fn2 = relay.Function([x2, y2], x2 - g1(x2, y2)) - fn2 = fn2.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) + fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1)) g2 = relay.GlobalVar("g2") mod[g2] = fn2 @@ -786,8 +786,8 @@ def get_mod(): x0 = relay.var("x0", shape=(3, 5)) y0 = relay.var("y0", shape=(3, 5)) fn0 = relay.Function([x0, y0], x0 * y0) - fn0 = fn0.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) - fn0 = fn0.set_attribute("Compiler", tvm.tir.StringImm("aa")) + fn0 = fn0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) + fn0 = fn0.with_attr("Compiler", tvm.tir.StringImm("aa")) g0 = relay.GlobalVar("g0") 
mod[g0] = fn0 @@ -800,7 +800,7 @@ def get_mod(): x2 = relay.var("x2", shape=(3, 5)) y2 = relay.var("y2", shape=(3, 5)) fn2 = relay.Function([x2, y2], x2 - g1(x2, y2)) - fn2 = fn2.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) + fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1)) g2 = relay.GlobalVar("g2") mod[g2] = fn2 return mod @@ -810,8 +810,8 @@ def expected(): x0 = relay.var("x0", shape=(3, 5)) y0 = relay.var("y0", shape=(3, 5)) fn0 = relay.Function([x0, y0], x0 * y0) - fn0 = fn0.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) - fn0 = fn0.set_attribute("Compiler", tvm.tir.StringImm("aa")) + fn0 = fn0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) + fn0 = fn0.with_attr("Compiler", tvm.tir.StringImm("aa")) x1 = relay.var("x1", shape=(3, 5)) y1 = relay.var("y1", shape=(3, 5)) @@ -822,7 +822,7 @@ def expected(): x2 = relay.var("x2", shape=(3, 5)) y2 = relay.var("y2", shape=(3, 5)) fn2 = relay.Function([x2, y2], x2 - g1(x2, y2)) - fn2 = fn2.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) + fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1)) g2 = relay.GlobalVar("g2") mod[g2] = fn2 diff --git a/tests/python/relay/test_pass_merge_composite.py b/tests/python/relay/test_pass_merge_composite.py index bcf61a01f1dba..63fb6b77dca4e 100644 --- a/tests/python/relay/test_pass_merge_composite.py +++ b/tests/python/relay/test_pass_merge_composite.py @@ -164,7 +164,7 @@ def expected(): add_node = relay.add(in_1, in_2) relu_node = relay.nn.relu(add_node) add_relu = relay.Function([in_1, in_2], relu_node) - add_relu = add_relu.set_attribute("Composite", tir.StringImm("add_relu")) + add_relu = add_relu.with_attr("Composite", tir.StringImm("add_relu")) # merged function r = relay.Call(add_relu, [a, b]) @@ -229,7 +229,7 @@ def expected(): sub_node = relay.subtract(in_1, in_2) mul_node = relay.multiply(add_node, sub_node) add_sub_mul = relay.Function([in_1, in_2], mul_node) - add_sub_mul = add_sub_mul.set_attribute("Composite", + add_sub_mul = add_sub_mul.with_attr("Composite", tir.StringImm("add_sub_mul")) # add_sub_mul1 function @@ -239,7 +239,7 @@ def expected(): sub_node_1 = relay.subtract(in_3, in_4) mul_node_1 = relay.multiply(add_node_1, sub_node_1) add_sub_mul_1 = relay.Function([in_3, in_4], mul_node_1) - add_sub_mul_1 = add_sub_mul_1.set_attribute("Composite", + add_sub_mul_1 = add_sub_mul_1.with_attr("Composite", tir.StringImm("add_sub_mul")) # merged function @@ -299,7 +299,7 @@ def expected(): add_node_1 = relay.add(in_1, add_node) add_node_2 = relay.add(add_node_1, add_node) add_add_add = relay.Function([in_1, in_2], add_node_2) - add_add_add = add_add_add.set_attribute("Composite", + add_add_add = add_add_add.with_attr("Composite", tir.StringImm("add_add_add")) # merged function @@ -383,7 +383,7 @@ def expected(): bias_node = relay.nn.bias_add(conv_node, in_3) r = relay.nn.relu(bias_node) conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r) - conv_bias_add_relu = conv_bias_add_relu.set_attribute("Composite", + conv_bias_add_relu = conv_bias_add_relu.with_attr("Composite", tir.StringImm("conv2d_bias_relu")) # add_relu function @@ -392,7 +392,7 @@ def expected(): add_node = relay.add(in_4, in_5) r = relay.nn.relu(add_node) add_relu = relay.Function([in_4, in_5], r) - add_relu = add_relu.set_attribute("Composite", tir.StringImm("add_relu")) + add_relu = add_relu.with_attr("Composite", tir.StringImm("add_relu")) # merged function conv_bias_add_relu_1 = relay.Call(conv_bias_add_relu, [data, kernel, bias]) @@ -461,7 +461,7 @@ def after_A_priority(composite_name): out = 
relay.abs(out) out = relay.nn.relu(out) merged_func = relay.Function([x, y], out) - merged_func = merged_func.set_attribute('Composite', + merged_func = merged_func.with_attr('Composite', tir.StringImm(composite_name)) ret = relay.Call(merged_func, [input_1, input_2]) return relay.Function([input_1, input_2], ret) @@ -527,13 +527,13 @@ def after(): y = relay.var('y') branch_1 = relay.multiply(relay.add(x, y), relay.subtract(x, y)) func_1 = relay.Function([x, y], branch_1) - func_1 = func_1.set_attribute('Composite', tir.StringImm("add_sub_mul")) + func_1 = func_1.with_attr('Composite', tir.StringImm("add_sub_mul")) call_1 = relay.Call(func_1, [input_1, input_2]) x1 = relay.var('x1') y1 = relay.var('y1') branch_2 = relay.multiply(relay.add(x1, y1), relay.subtract(x1, y1)) func_2 = relay.Function([x1, y1], branch_2) - func_2 = func_2.set_attribute('Composite', tir.StringImm("add_sub_mul")) + func_2 = func_2.with_attr('Composite', tir.StringImm("add_sub_mul")) call_2 = relay.Call(func_2, [input_1, input_2]) out = relay.multiply(call_1, call_2) return relay.Function([input_1, input_2], out) @@ -612,14 +612,14 @@ def after_A(): add_relu_1 = relay.add(x, y) add_relu_1 = relay.nn.relu(add_relu_1) add_relu_1 = relay.Function([x, y], add_relu_1) - add_relu_1 = add_relu_1.set_attribute('Composite', tir.StringImm('add_relu')) + add_relu_1 = add_relu_1.with_attr('Composite', tir.StringImm('add_relu')) add_relu_call_1 = relay.Call(add_relu_1, [inputs[0], inputs[1]]) x1 = relay.var('x1') y1 = relay.var('y1') add_relu_2 = relay.add(x1, y1) add_relu_2 = relay.nn.relu(add_relu_2) add_relu_2 = relay.Function([x1, y1], add_relu_2) - add_relu_2 = add_relu_2.set_attribute('Composite', tir.StringImm('add_relu')) + add_relu_2 = add_relu_2.with_attr('Composite', tir.StringImm('add_relu')) add_relu_call_2 = relay.Call(add_relu_2, [inputs[2], inputs[3]]) x2 = relay.var('x2') y2 = relay.var('y2') @@ -627,7 +627,7 @@ def after_A(): sub = relay.subtract(x2, y2) add_sub_mul = relay.multiply(add, sub) add_sub_mul = relay.Function([x2, y2], add_sub_mul) - add_sub_mul = add_sub_mul.set_attribute('Composite', tir.StringImm('add_sub_mul')) + add_sub_mul = add_sub_mul.with_attr('Composite', tir.StringImm('add_sub_mul')) add_sub_mul_call = relay.Call(add_sub_mul, [add_relu_call_1, add_relu_call_2]) return relay.Function(inputs, add_sub_mul_call) @@ -640,7 +640,7 @@ def after_B(): add_relu = relay.add(x, y) add_relu = relay.nn.relu(add_relu) add_relu = relay.Function([x, y], add_relu) - add_relu = add_relu.set_attribute('Composite', tir.StringImm('add_relu')) + add_relu = add_relu.with_attr('Composite', tir.StringImm('add_relu')) add_relu_call = relay.Call(add_relu, [inputs[i*2], inputs[i*2+1]]) add_relu_calls.append(add_relu_call) diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 209376a4e94b9..11f7e971b2833 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -303,11 +303,11 @@ def expected(): add = x0 + y0 # Function that uses C compiler func = relay.Function([x0, y0], add) - func = func.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) - func = func.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) - func = func.set_attribute("Compiler", + func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) + func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1)) + func = func.with_attr("Compiler", tvm.tir.StringImm("ccompiler")) - func = func.set_attribute("ExternalSymbol", + func = 
func.with_attr("ExternalSymbol", tvm.tir.StringImm("ccompiler_0")) glb_0 = relay.GlobalVar("ccompiler_0") mod[glb_0] = func @@ -318,7 +318,7 @@ def expected(): exp = relay.exp(p0) concat = relay.concatenate([log, exp], axis=0) fused_func = relay.Function([p0], concat) - fused_func = fused_func.set_attribute("Primitive", + fused_func = fused_func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) fused_call = relay.Call(fused_func, [add_call]) main = relay.Function([x, y], fused_call) @@ -390,10 +390,10 @@ def expected(): out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2) func = relay.Function([data0, input0, input1], out) - func = func.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) - func = func.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) - func = func.set_attribute("Compiler", tvm.tir.StringImm("dnnl")) - func = func.set_attribute("ExternalSymbol", + func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) + func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1)) + func = func.with_attr("Compiler", tvm.tir.StringImm("dnnl")) + func = func.with_attr("ExternalSymbol", tvm.tir.StringImm("dnnl_0")) glb_var = relay.GlobalVar("dnnl_0") mod = tvm.IRModule() @@ -516,11 +516,11 @@ def expected(): bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar) func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar], bn.astuple()) - func0 = func0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) - func0 = func0.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) - func0 = func0.set_attribute("Compiler", + func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) + func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) + func0 = func0.with_attr("Compiler", tvm.tir.StringImm("test_compiler")) - func0 = func0.set_attribute("ExternalSymbol", + func0 = func0.with_attr("ExternalSymbol", tvm.tir.StringImm("test_compiler_0")) gv0 = relay.GlobalVar("test_compiler_0") mod[gv0] = func0 @@ -535,11 +535,11 @@ def expected(): channels=16, padding=(1, 1)) func1 = relay.Function([data1, weight1], conv) - func1 = func1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) - func1 = func1.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) - func1 = func1.set_attribute("Compiler", + func1 = func1.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) + func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) + func1 = func1.with_attr("Compiler", tvm.tir.StringImm("test_compiler")) - func1 = func1.set_attribute("ExternalSymbol", + func1 = func1.with_attr("ExternalSymbol", tvm.tir.StringImm("test_compiler_1")) gv1 = relay.GlobalVar("test_compiler_1") mod[gv1] = func1 @@ -609,11 +609,11 @@ def expected(): bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar) func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar], bn.astuple()) - func0 = func0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) - func0 = func0.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) - func0 = func0.set_attribute("Compiler", + func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) + func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) + func0 = func0.with_attr("Compiler", tvm.tir.StringImm("test_compiler")) - func0 = func0.set_attribute("ExternalSymbol", + func0 = func0.with_attr("ExternalSymbol", tvm.tir.StringImm("test_compiler_0")) # main function diff --git a/tests/python/unittest/test_ir_attrs.py b/tests/python/unittest/test_ir_attrs.py new file mode 100644 index 0000000000000..a2be2b7bf11f2 --- /dev/null +++ 
b/tests/python/unittest/test_ir_attrs.py @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import te + +def test_make_attrs(): + try: + x = tvm.ir.make_node("attrs.TestAttrs", unknown_key=1, name="xx") + assert False + except tvm.error.TVMError as e: + assert str(e).find("unknown_key") != -1 + + try: + x = tvm.ir.make_node("attrs.TestAttrs", axis=100, name="xx") + assert False + except tvm.error.TVMError as e: + assert str(e).find("upper bound") != -1 + + x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3,4)) + assert x.name == "xx" + assert x.padding[0].value == 3 + assert x.padding[1].value == 4 + assert x.axis == 10 + + +def test_dict_attrs(): + dattr = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0)) + assert dattr.x.value == 1 + dattr = tvm.ir.load_json(tvm.ir.save_json(dattr)) + assert dattr.name.value == "xyz" + assert isinstance(dattr, tvm.ir.DictAttrs) + assert "name" in dattr + assert dattr["x"].value == 1 + assert len(dattr) == 4 + assert len([x for x in dattr.keys()]) == 4 + assert len(dattr.items()) == 4 + + +if __name__ == "__main__": + test_make_attrs() + test_dict_attrs() diff --git a/tests/python/unittest/test_lang_reflection.py b/tests/python/unittest/test_node_reflection.py similarity index 79% rename from tests/python/unittest/test_lang_reflection.py rename to tests/python/unittest/test_node_reflection.py index 1691d7d11a7aa..a25ba0ab42f0a 100644 --- a/tests/python/unittest/test_lang_reflection.py +++ b/tests/python/unittest/test_node_reflection.py @@ -54,33 +54,6 @@ def test_make_node(): assert AA.value_index == A.value_index -def test_make_attrs(): - try: - x = tvm.ir.make_node("attrs.TestAttrs", unknown_key=1, name="xx") - assert False - except tvm.error.TVMError as e: - assert str(e).find("unknown_key") != -1 - - try: - x = tvm.ir.make_node("attrs.TestAttrs", axis=100, name="xx") - assert False - except tvm.error.TVMError as e: - assert str(e).find("upper bound") != -1 - - x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3,4)) - assert x.name == "xx" - assert x.padding[0].value == 3 - assert x.padding[1].value == 4 - assert x.axis == 10 - - - dattr = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0)) - assert dattr.x.value == 1 - datrr = tvm.ir.load_json(tvm.ir.save_json(dattr)) - assert dattr.name.value == "xyz" - - - def test_make_sum(): A = te.placeholder((2, 10), name='A') k = te.reduce_axis((0,10), "k")
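
A minimal usage sketch of the attribute API after this change, assuming a TVM build with this patch applied. The calls mirror the updated tests: set_attribute becomes with_attr, which returns a new function, and attribute values are read back through func.attrs; on the C++ side the equivalents used above are WithAttr(std::move(f), attr::kPrimitive, tvm::Integer(1)) and f->HasNonzeroAttr(attr::kPrimitive).

import tvm
from tvm import relay

x = relay.var("x", shape=(3, 5))
func = relay.Function([x], x)
# with_attr returns a fresh Function carrying the attribute; the stored
# value is the TIR node passed in (here an IntImm), hence .value below.
func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
assert func.attrs["Primitive"].value == 1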